xinference 0.9.0__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.
Files changed (47)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +33 -0
  3. xinference/client/common.py +2 -0
  4. xinference/client/restful/restful_client.py +49 -17
  5. xinference/conftest.py +4 -1
  6. xinference/core/supervisor.py +11 -1
  7. xinference/core/worker.py +29 -9
  8. xinference/deploy/cmdline.py +73 -2
  9. xinference/deploy/utils.py +25 -1
  10. xinference/device_utils.py +0 -2
  11. xinference/model/core.py +13 -2
  12. xinference/model/image/core.py +16 -2
  13. xinference/model/image/stable_diffusion/core.py +25 -2
  14. xinference/model/llm/__init__.py +17 -0
  15. xinference/model/llm/core.py +18 -2
  16. xinference/model/llm/ggml/llamacpp.py +3 -19
  17. xinference/model/llm/llm_family.json +8 -3
  18. xinference/model/llm/llm_family.py +100 -29
  19. xinference/model/llm/llm_family_modelscope.json +57 -3
  20. xinference/model/llm/pytorch/baichuan.py +2 -0
  21. xinference/model/llm/pytorch/chatglm.py +2 -0
  22. xinference/model/llm/pytorch/core.py +23 -0
  23. xinference/model/llm/pytorch/falcon.py +4 -0
  24. xinference/model/llm/pytorch/internlm2.py +2 -0
  25. xinference/model/llm/pytorch/llama_2.py +4 -0
  26. xinference/model/llm/pytorch/qwen_vl.py +1 -0
  27. xinference/model/llm/pytorch/vicuna.py +2 -0
  28. xinference/model/llm/pytorch/yi_vl.py +1 -0
  29. xinference/types.py +5 -2
  30. xinference/web/ui/build/asset-manifest.json +3 -3
  31. xinference/web/ui/build/index.html +1 -1
  32. xinference/web/ui/build/static/js/{main.87d39ffb.js → main.78829790.js} +3 -3
  33. xinference/web/ui/build/static/js/main.78829790.js.map +1 -0
  34. xinference/web/ui/node_modules/.cache/babel-loader/18e5d5422e2464abf4a3e6d38164570e2e426e0a921e9a2628bbae81b18da353.json +1 -0
  35. xinference/web/ui/node_modules/.cache/babel-loader/98b7ef307f436affe13d75a4f265b27e828ccc2b10ffae6513abe2681bc11971.json +1 -0
  36. xinference/web/ui/node_modules/.cache/babel-loader/e8687f75d2adacd34852b71c41ca17203d6fb4c8999ea55325bb2939f9d9ea90.json +1 -0
  37. {xinference-0.9.0.dist-info → xinference-0.9.2.dist-info}/METADATA +7 -5
  38. {xinference-0.9.0.dist-info → xinference-0.9.2.dist-info}/RECORD +43 -43
  39. xinference/web/ui/build/static/js/main.87d39ffb.js.map +0 -1
  40. xinference/web/ui/node_modules/.cache/babel-loader/0738899eefad7f90261125823d87ea9f0d53667b1479a0c1f398aff14f2bbd2a.json +0 -1
  41. xinference/web/ui/node_modules/.cache/babel-loader/64accc515dc6cd584a2873796cd7da6f93de57f7e465eb5423cca9a2f3fe3eff.json +0 -1
  42. xinference/web/ui/node_modules/.cache/babel-loader/77d4d795f078408fa2dd49da26d1ba1543d51b63cc253e736f4bef2e6014e888.json +0 -1
  43. /xinference/web/ui/build/static/js/{main.87d39ffb.js.LICENSE.txt → main.78829790.js.LICENSE.txt} +0 -0
  44. {xinference-0.9.0.dist-info → xinference-0.9.2.dist-info}/LICENSE +0 -0
  45. {xinference-0.9.0.dist-info → xinference-0.9.2.dist-info}/WHEEL +0 -0
  46. {xinference-0.9.0.dist-info → xinference-0.9.2.dist-info}/entry_points.txt +0 -0
  47. {xinference-0.9.0.dist-info → xinference-0.9.2.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-02-22T15:40:53+0800",
+ "date": "2024-03-08T13:28:03+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "c653c975847f9f6a81382033a9c8f5bd81bf70f2",
- "version": "0.9.0"
+ "full-revisionid": "29f4c10a854cfec684dcf8398a0974f64bf8ce2b",
+ "version": "0.9.2"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -219,6 +219,11 @@ class RESTfulAPI:
         self._router.add_api_route(
             "/v1/models/families", self._get_builtin_families, methods=["GET"]
         )
+        self._router.add_api_route(
+            "/v1/models/vllm-supported",
+            self.list_vllm_supported_model_families,
+            methods=["GET"],
+        )
         self._router.add_api_route(
             "/v1/cluster/info", self.get_cluster_device_info, methods=["GET"]
         )
@@ -651,6 +656,9 @@ class RESTfulAPI:
         replica = payload.get("replica", 1)
         n_gpu = payload.get("n_gpu", "auto")
         request_limits = payload.get("request_limits", None)
+        peft_model_path = payload.get("peft_model_path", None)
+        image_lora_load_kwargs = payload.get("image_lora_load_kwargs", None)
+        image_lora_fuse_kwargs = payload.get("image_lora_fuse_kwargs", None)
 
         exclude_keys = {
             "model_uid",
@@ -662,6 +670,9 @@
             "replica",
             "n_gpu",
             "request_limits",
+            "peft_model_path",
+            "image_lora_load_kwargs",
+            "image_lora_fuse_kwargs",
         }
 
         kwargs = {
@@ -686,6 +697,9 @@
             n_gpu=n_gpu,
             request_limits=request_limits,
             wait_ready=wait_ready,
+            peft_model_path=peft_model_path,
+            image_lora_load_kwargs=image_lora_load_kwargs,
+            image_lora_fuse_kwargs=image_lora_fuse_kwargs,
             **kwargs,
         )
 
@@ -845,6 +859,7 @@
         }
         kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
+        # TODO: Decide if this default value override is necessary #1061
         if body.max_tokens is None:
             kwargs["max_tokens"] = max_tokens_field.default
 
@@ -1136,6 +1151,7 @@
         }
         kwargs = body.dict(exclude_unset=True, exclude=exclude)
 
+        # TODO: Decide if this default value override is necessary #1061
         if body.max_tokens is None:
             kwargs["max_tokens"] = max_tokens_field.default
 
@@ -1256,6 +1272,7 @@
                     self.handle_request_limit_error(re)
                 async for item in iterator:
                     yield item
+                yield "[DONE]"
             except Exception as ex:
                 logger.exception("Chat completion stream got an error: %s", ex)
                 await self._report_error_event(model_uid, str(ex))
@@ -1348,6 +1365,22 @@
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def list_vllm_supported_model_families(self) -> JSONResponse:
+        try:
+            from ..model.llm.vllm.core import (
+                VLLM_SUPPORTED_CHAT_MODELS,
+                VLLM_SUPPORTED_MODELS,
+            )
+
+            data = {
+                "chat": VLLM_SUPPORTED_CHAT_MODELS,
+                "generate": VLLM_SUPPORTED_MODELS,
+            }
+            return JSONResponse(content=data)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def get_cluster_device_info(
         self, detailed: bool = Query(False)
     ) -> JSONResponse:
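
The new /v1/models/vllm-supported route simply returns the chat and generate model families that the vLLM backend can serve. As a usage sketch, not part of the diff (the endpoint address is a placeholder for your own deployment), it can be queried with plain HTTP once a server is running:

    import requests

    # Placeholder endpoint; replace with your deployment's address.
    resp = requests.get("http://127.0.0.1:9997/v1/models/vllm-supported")
    resp.raise_for_status()
    families = resp.json()
    print(families["chat"])      # vLLM-compatible chat model families
    print(families["generate"])  # vLLM-compatible generate model families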
xinference/client/common.py CHANGED
@@ -43,6 +43,8 @@ def streaming_response_iterator(
         line = line.strip()
         if line.startswith(b"data:"):
             json_str = line[len(b"data:") :].strip()
+            if json_str == b"[DONE]":
+                continue
             data = json.loads(json_str.decode("utf-8"))
             error = data.get("error")
             if error is not None:
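
Together with the server-side `yield "[DONE]"` above, streaming responses now end with an OpenAI-style sentinel that the client iterator silently skips. A sketch of how an independent SSE consumer could treat the same stream (helper name is hypothetical, not from the package):

    import json

    def iter_stream_events(raw_lines):
        # Parse 'data:' lines from an SSE byte stream, stopping at the [DONE] sentinel.
        for line in raw_lines:
            line = line.strip()
            if not line.startswith(b"data:"):
                continue
            payload = line[len(b"data:"):].strip()
            if payload == b"[DONE]":
                break
            yield json.loads(payload.decode("utf-8"))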
xinference/client/restful/restful_client.py CHANGED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import typing
 import warnings
 from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
 
@@ -47,6 +48,25 @@ def _get_error_string(response: requests.Response) -> str:
     return "Unknown error"
 
 
+@typing.no_type_check
+def handle_system_prompts(
+    chat_history: List["ChatCompletionMessage"], system_prompt: Optional[str]
+) -> List["ChatCompletionMessage"]:
+    history_system_prompts = [
+        ch["content"] for ch in chat_history if ch["role"] == "system"
+    ]
+    if system_prompt is not None:
+        history_system_prompts.append(system_prompt)
+
+    # remove all the system prompt in the chat_history
+    chat_history = list(filter(lambda x: x["role"] != "system", chat_history))
+    # insert all system prompts at the beginning
+    chat_history.insert(
+        0, {"role": "system", "content": ". ".join(history_system_prompts)}
+    )
+    return chat_history
+
+
 class RESTfulModelHandle:
     """
     A sync model interface (for RESTful client) which provides type hints that makes it much easier to use xinference
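
The new handle_system_prompts helper collapses every system message already present in chat_history, plus an optional explicit system_prompt, into a single leading system message joined with ". ". An illustrative call (messages invented for the example):

    history = [
        {"role": "system", "content": "You are concise"},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ]
    merged = handle_system_prompts(history, system_prompt="Answer in French")
    # merged[0] == {"role": "system", "content": "You are concise. Answer in French"}
    # the user/assistant messages keep their original order after it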
@@ -363,15 +383,8 @@ class RESTfulChatModelHandle(RESTfulGenerateModelHandle):
         if chat_history is None:
             chat_history = []
 
-        if chat_history and chat_history[0]["role"] == "system":
-            if system_prompt is not None:
-                chat_history[0]["content"] = system_prompt
-
-        else:
-            if system_prompt is not None:
-                chat_history.insert(0, {"role": "system", "content": system_prompt})
-
-        chat_history.append({"role": "user", "content": prompt})
+        chat_history = handle_system_prompts(chat_history, system_prompt)
+        chat_history.append({"role": "user", "content": prompt})  # type: ignore
 
         request_body: Dict[str, Any] = {
             "model": self._model_uid,
@@ -444,14 +457,8 @@ class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle):
         if chat_history is None:
             chat_history = []
 
-        if chat_history and chat_history[0]["role"] == "system":
-            if system_prompt is not None:
-                chat_history[0]["content"] = system_prompt
-        else:
-            if system_prompt is not None:
-                chat_history.insert(0, {"role": "system", "content": system_prompt})
-
-        chat_history.append({"role": "user", "content": prompt})
+        chat_history = handle_system_prompts(chat_history, system_prompt)
+        chat_history.append({"role": "user", "content": prompt})  # type: ignore
 
         request_body: Dict[str, Any] = {
             "model": self._model_uid,
@@ -676,6 +683,19 @@ class Client:
         response_data = response.json()
         self._cluster_authed = bool(response_data["auth"])
 
+    def vllm_models(self) -> Dict[str, Any]:
+        url = f"{self.base_url}/v1/models/vllm-supported"
+        response = requests.get(url, headers=self._headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to fetch VLLM models. detail: {response.json()['detail']}"
+            )
+
+        try:
+            return response.json()
+        except Exception as e:
+            raise RuntimeError(f"Error parsing JSON response: {e}")
+
     def login(self, username: str, password: str):
         if not self._cluster_authed:
             return
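
On the client side the same information is exposed through the new vllm_models() method. A small usage sketch (the endpoint is a placeholder; RESTfulClient is the exported name of this Client class):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")
    supported = client.vllm_models()
    print(supported["chat"])      # chat families the vLLM backend can load
    print(supported["generate"])  # generate families the vLLM backend can load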
@@ -771,6 +791,9 @@
         replica: int = 1,
         n_gpu: Optional[Union[int, str]] = "auto",
         request_limits: Optional[int] = None,
+        peft_model_path: Optional[str] = None,
+        image_lora_load_kwargs: Optional[Dict] = None,
+        image_lora_fuse_kwargs: Optional[Dict] = None,
         **kwargs,
     ) -> str:
         """
@@ -798,6 +821,12 @@
         request_limits: Optional[int]
             The number of request limits for this model, default is None.
             ``request_limits=None`` means no limits for this model.
+        peft_model_path: Optional[str]
+            PEFT (Parameter-Efficient Fine-Tuning) model path.
+        image_lora_load_kwargs: Optional[Dict]
+            lora load parameters for image model
+        image_lora_fuse_kwargs: Optional[Dict]
+            lora fuse parameters for image model
         **kwargs:
             Any other parameters been specified.
 
@@ -820,6 +849,9 @@
             "replica": replica,
             "n_gpu": n_gpu,
             "request_limits": request_limits,
+            "peft_model_path": peft_model_path,
+            "image_lora_load_kwargs": image_lora_load_kwargs,
+            "image_lora_fuse_kwargs": image_lora_fuse_kwargs,
         }
 
         for key, value in kwargs.items():
xinference/conftest.py CHANGED
@@ -25,6 +25,10 @@ from typing import Dict, Optional
 import pytest
 import xoscar as xo
 
+# skip health checking for CI
+if os.environ.get("GITHUB_ACTIONS"):
+    os.environ["XINFERENCE_DISABLE_HEALTH_CHECK"] = "1"
+
 from .api.oauth2.types import AuthConfig, AuthStartupConfig, User
 from .constants import XINFERENCE_LOG_BACKUP_COUNT, XINFERENCE_LOG_MAX_BYTES
 from .core.supervisor import SupervisorActor
@@ -134,7 +138,6 @@ async def _start_test_cluster(
     logging_conf: Optional[Dict] = None,
 ):
     logging.config.dictConfig(logging_conf)  # type: ignore
-
     pool = None
     try:
         pool = await create_worker_actor_pool(
xinference/core/supervisor.py CHANGED
@@ -714,6 +714,9 @@ class SupervisorActor(xo.StatelessActor):
         request_limits: Optional[int] = None,
         wait_ready: bool = True,
         model_version: Optional[str] = None,
+        peft_model_path: Optional[str] = None,
+        image_lora_load_kwargs: Optional[Dict] = None,
+        image_lora_fuse_kwargs: Optional[Dict] = None,
         **kwargs,
     ) -> str:
         if model_uid is None:
@@ -751,6 +754,9 @@
             model_type=model_type,
             n_gpu=n_gpu,
             request_limits=request_limits,
+            peft_model_path=peft_model_path,
+            image_lora_load_kwargs=image_lora_load_kwargs,
+            image_lora_fuse_kwargs=image_lora_fuse_kwargs,
             **kwargs,
         )
         self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
@@ -922,7 +928,11 @@
         workers = list(self._worker_address_to_worker.values())
         for worker in workers:
             ret.update(await worker.list_models())
-        return {parse_replica_model_uid(k)[0]: v for k, v in ret.items()}
+        running_model_info = {parse_replica_model_uid(k)[0]: v for k, v in ret.items()}
+        # add replica count
+        for k, v in running_model_info.items():
+            v["replica"] = self._model_uid_to_replica_info[k].replica
+        return running_model_info
 
     def is_local_deployment(self) -> bool:
         # TODO: temporary.
xinference/core/worker.py CHANGED
@@ -27,7 +27,11 @@ import xoscar as xo
 from async_timeout import timeout
 from xoscar import MainActorPoolType
 
-from ..constants import XINFERENCE_CACHE_DIR
+from ..constants import (
+    XINFERENCE_CACHE_DIR,
+    XINFERENCE_DISABLE_HEALTH_CHECK,
+    XINFERENCE_HEALTH_CHECK_INTERVAL,
+)
 from ..core import ModelActor
 from ..core.status_guard import LaunchStatus
 from ..device_utils import gpu_count
@@ -40,7 +44,6 @@ from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir
 logger = getLogger(__name__)
 
 
-DEFAULT_NODE_HEARTBEAT_INTERVAL = 5
 MODEL_ACTOR_AUTO_RECOVER_LIMIT: Optional[int]
 _MODEL_ACTOR_AUTO_RECOVER_LIMIT = os.getenv("XINFERENCE_MODEL_ACTOR_AUTO_RECOVER_LIMIT")
 if _MODEL_ACTOR_AUTO_RECOVER_LIMIT is not None:
@@ -177,12 +180,13 @@ class WorkerActor(xo.StatelessActor):
             address=self._supervisor_address, uid=SupervisorActor.uid()
         )
         await self._supervisor_ref.add_worker(self.address)
-        # Run _periodical_report_status() in a dedicated thread.
-        self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
-        self._isolation.start()
-        asyncio.run_coroutine_threadsafe(
-            self._periodical_report_status(), loop=self._isolation.loop
-        )
+        if not XINFERENCE_DISABLE_HEALTH_CHECK:
+            # Run _periodical_report_status() in a dedicated thread.
+            self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
+            self._isolation.start()
+            asyncio.run_coroutine_threadsafe(
+                self._periodical_report_status(), loop=self._isolation.loop
+            )
         logger.info(f"Xinference worker {self.address} started")
         logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
         purge_dir(XINFERENCE_CACHE_DIR)
@@ -487,6 +491,9 @@
         quantization: Optional[str],
         model_type: str = "LLM",
         n_gpu: Optional[Union[int, str]] = "auto",
+        peft_model_path: Optional[str] = None,
+        image_lora_load_kwargs: Optional[Dict] = None,
+        image_lora_fuse_kwargs: Optional[Dict] = None,
         request_limits: Optional[int] = None,
         **kwargs,
     ):
@@ -512,6 +519,16 @@
         if isinstance(n_gpu, str) and n_gpu != "auto":
             raise ValueError("Currently `n_gpu` only supports `auto`.")
 
+        if peft_model_path is not None:
+            if model_type in ("embedding", "rerank"):
+                raise ValueError(
+                    f"PEFT adaptors cannot be applied to embedding or rerank models."
+                )
+            if model_type == "LLM" and model_format in ("ggufv2", "ggmlv3"):
+                raise ValueError(
+                    f"PEFT adaptors can only be applied to pytorch-like models"
+                )
+
         assert model_uid not in self._model_uid_to_model
         self._check_model_is_valid(model_name, model_format)
         assert self._supervisor_ref is not None
@@ -533,6 +550,9 @@
             model_format,
             model_size_in_billions,
             quantization,
+            peft_model_path,
+            image_lora_load_kwargs,
+            image_lora_fuse_kwargs,
             is_local_deployment,
             **kwargs,
         )
@@ -662,7 +682,7 @@
             ) as ex:  # pragma: no cover # noqa: E722 # nosec # pylint: disable=bare-except
                 logger.error(f"Failed to upload node info: {ex}")
             try:
-                await asyncio.sleep(DEFAULT_NODE_HEARTBEAT_INTERVAL)
+                await asyncio.sleep(XINFERENCE_HEALTH_CHECK_INTERVAL)
             except asyncio.CancelledError:  # pragma: no cover
                 break
 
xinference/deploy/cmdline.py CHANGED
@@ -17,7 +17,7 @@ import logging
 import os
 import sys
 import warnings
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import click
 from xoscar.utils import get_next_port
@@ -40,7 +40,12 @@ from ..constants (
 )
 from ..isolation import Isolation
 from ..types import ChatCompletionMessage
-from .utils import get_config_dict, get_log_file, get_timestamp_ms
+from .utils import (
+    get_config_dict,
+    get_log_file,
+    get_timestamp_ms,
+    handle_click_args_type,
+)
 
 try:
     # provide elaborate line editing and history features.
@@ -525,6 +530,10 @@ def list_model_registrations(
 @cli.command(
     "launch",
     help="Launch a model with the Xinference framework with the given parameters.",
+    context_settings=dict(
+        ignore_unknown_options=True,
+        allow_extra_args=True,
+    ),
 )
 @click.option(
     "--endpoint",
@@ -587,13 +596,35 @@
     type=str,
     help='The number of GPUs used by the model, default is "auto".',
 )
+@click.option(
+    "--peft-model-path",
+    default=None,
+    type=str,
+    help="PEFT model path.",
+)
+@click.option(
+    "--image-lora-load-kwargs",
+    "-ld",
+    "image_lora_load_kwargs",
+    type=(str, str),
+    multiple=True,
+)
+@click.option(
+    "--image-lora-fuse-kwargs",
+    "-fd",
+    "image_lora_fuse_kwargs",
+    type=(str, str),
+    multiple=True,
+)
 @click.option(
     "--trust-remote-code",
     default=True,
     type=bool,
    help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
 )
+@click.pass_context
 def model_launch(
+    ctx,
     endpoint: Optional[str],
     model_name: str,
     model_type: str,
@@ -603,8 +634,18 @@
     quantization: str,
     replica: int,
     n_gpu: str,
+    peft_model_path: Optional[str],
+    image_lora_load_kwargs: Optional[Tuple],
+    image_lora_fuse_kwargs: Optional[Tuple],
     trust_remote_code: bool,
 ):
+    kwargs = {}
+    for i in range(0, len(ctx.args), 2):
+        if not ctx.args[i].startswith("--"):
+            raise ValueError("You must specify extra kwargs with `--` prefix.")
+        kwargs[ctx.args[i][2:]] = handle_click_args_type(ctx.args[i + 1])
+    print(f"Launch model name: {model_name} with kwargs: {kwargs}", file=sys.stderr)
+
     if n_gpu.lower() == "none":
         _n_gpu: Optional[Union[int, str]] = None
     elif n_gpu == "auto":
@@ -612,6 +653,17 @@
     else:
         _n_gpu = int(n_gpu)
 
+    image_lora_load_params = (
+        {k: handle_click_args_type(v) for k, v in dict(image_lora_load_kwargs).items()}
+        if image_lora_load_kwargs
+        else None
+    )
+    image_lora_fuse_params = (
+        {k: handle_click_args_type(v) for k, v in dict(image_lora_fuse_kwargs).items()}
+        if image_lora_fuse_kwargs
+        else None
+    )
+
     endpoint = get_endpoint(endpoint)
     model_size: Optional[Union[str, int]] = (
         size_in_billions
@@ -630,7 +682,11 @@
         quantization=quantization,
         replica=replica,
         n_gpu=_n_gpu,
+        peft_model_path=peft_model_path,
+        image_lora_load_kwargs=image_lora_load_params,
+        image_lora_fuse_kwargs=image_lora_fuse_params,
         trust_remote_code=trust_remote_code,
+        **kwargs,
     )
 
     print(f"Model uid: {model_uid}", file=sys.stderr)
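
With these options a LoRA adapter can be attached at launch time from the command line, and any unrecognized `--key value` pairs are forwarded to the model as extra kwargs via handle_click_args_type. An illustrative invocation, not taken from the diff (model name, adapter path, and kwarg values are placeholders):

    xinference launch --model-name stable-diffusion-xl-base-1.0 \
        --model-type image \
        --peft-model-path /path/to/lora_adapter \
        -ld weight_name pytorch_lora_weights.safetensors \
        -fd lora_scale 0.6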
@@ -925,6 +981,21 @@ def model_chat(
     )
 
 
+@cli.command("vllm-models", help="Query and display models compatible with VLLM.")
+@click.option("--endpoint", "-e", type=str, help="Xinference endpoint.")
+def vllm_models(endpoint: Optional[str]):
+    endpoint = get_endpoint(endpoint)
+    client = RESTfulClient(base_url=endpoint)
+    client._set_token(get_stored_token(endpoint, client))
+    vllm_models_dict = client.vllm_models()
+    print("VLLM supported model families:")
+    chat_models = vllm_models_dict["chat"]
+    supported_models = vllm_models_dict["generate"]
+
+    print("VLLM supported chat model families:", chat_models)
+    print("VLLM supported generate model families:", supported_models)
+
+
 @cli.command("login", help="Login when the cluster is authenticated.")
 @click.option("--endpoint", "-e", type=str, help="Xinference endpoint.")
 @click.option("--username", type=str, required=True, help="Username.")
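
The new subcommand prints the same families from a terminal; an illustrative session against a running cluster (endpoint is a placeholder and the output is abridged):

    $ xinference vllm-models --endpoint http://127.0.0.1:9997
    VLLM supported model families:
    VLLM supported chat model families: ['llama-2-chat', ...]
    VLLM supported generate model families: ['llama-2', ...]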
xinference/deploy/utils.py CHANGED
@@ -15,7 +15,8 @@
 import logging
 import os
 import time
-from typing import TYPE_CHECKING, Optional
+import typing
+from typing import TYPE_CHECKING, Any, Optional
 
 import xoscar as xo
 
@@ -159,3 +160,26 @@ def health_check(address: str, max_attempts: int, sleep_interval: int = 3) -> bo
 def get_timestamp_ms():
     t = time.time()
     return int(round(t * 1000))
+
+
+@typing.no_type_check
+def handle_click_args_type(arg: str) -> Any:
+    if arg == "None":
+        return None
+    if arg in ("True", "true"):
+        return True
+    if arg in ("False", "false"):
+        return False
+    try:
+        result = int(arg)
+        return result
+    except:
+        pass
+
+    try:
+        result = float(arg)
+        return result
+    except:
+        pass
+
+    return arg
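
handle_click_args_type coerces the string values collected from the command line into Python literals before they are forwarded as kwargs. A few illustrative conversions:

    handle_click_args_type("None")   # -> None
    handle_click_args_type("true")   # -> True
    handle_click_args_type("8000")   # -> 8000 (int)
    handle_click_args_type("0.6")    # -> 0.6 (float)
    handle_click_args_type("cuda")   # -> "cuda" (left as a plain string)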
xinference/device_utils.py CHANGED
@@ -92,8 +92,6 @@ def gpu_count():
         )
 
         return min(torch.cuda.device_count(), len(cuda_visible_devices))
-    elif torch.backends.mps.is_available():
-        return 1
     elif is_xpu_available():
         return torch.xpu.device_count()
     else:
xinference/model/core.py CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from .._compat import BaseModel
 
@@ -52,6 +52,9 @@ def create_model_instance(
     model_format: Optional[str] = None,
     model_size_in_billions: Optional[int] = None,
     quantization: Optional[str] = None,
+    peft_model_path: Optional[str] = None,
+    image_lora_load_kwargs: Optional[Dict] = None,
+    image_lora_fuse_kwargs: Optional[Dict] = None,
     is_local_deployment: bool = False,
     **kwargs,
 ) -> Tuple[Any, ModelDescription]:
@@ -70,6 +73,7 @@
         model_format,
         model_size_in_billions,
         quantization,
+        peft_model_path,
         is_local_deployment,
         **kwargs,
     )
@@ -82,7 +86,14 @@
     elif model_type == "image":
         kwargs.pop("trust_remote_code", None)
         return create_image_model_instance(
-            subpool_addr, devices, model_uid, model_name, **kwargs
+            subpool_addr,
+            devices,
+            model_uid,
+            model_name,
+            lora_model_path=peft_model_path,
+            lora_load_kwargs=image_lora_load_kwargs,
+            lora_fuse_kwargs=image_lora_fuse_kwargs,
+            **kwargs,
         )
     elif model_type == "rerank":
         kwargs.pop("trust_remote_code", None)
xinference/model/image/core.py CHANGED
@@ -155,7 +155,14 @@ def get_cache_status(
 
 
 def create_image_model_instance(
-    subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
+    subpool_addr: str,
+    devices: List[str],
+    model_uid: str,
+    model_name: str,
+    lora_model_path: Optional[str] = None,
+    lora_load_kwargs: Optional[Dict] = None,
+    lora_fuse_kwargs: Optional[Dict] = None,
+    **kwargs,
 ) -> Tuple[DiffusionModel, ImageModelDescription]:
     model_spec = match_diffusion(model_name)
     controlnet = kwargs.get("controlnet")
@@ -187,7 +194,14 @@
     else:
         kwargs["controlnet"] = controlnet_model_paths
     model_path = cache(model_spec)
-    model = DiffusionModel(model_uid, model_path, **kwargs)
+    model = DiffusionModel(
+        model_uid,
+        model_path,
+        lora_model_path=lora_model_path,
+        lora_load_kwargs=lora_load_kwargs,
+        lora_fuse_kwargs=lora_fuse_kwargs,
+        **kwargs,
+    )
     model_description = ImageModelDescription(
         subpool_addr, devices, model_spec, model_path=model_path
     )
xinference/model/image/stable_diffusion/core.py CHANGED
@@ -21,7 +21,7 @@ import uuid
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from io import BytesIO
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from ....constants import XINFERENCE_IMAGE_DIR
 from ....device_utils import move_model_to_available_device
@@ -32,14 +32,36 @@ logger = logging.getLogger(__name__)
 
 class DiffusionModel:
     def __init__(
-        self, model_uid: str, model_path: str, device: Optional[str] = None, **kwargs
+        self,
+        model_uid: str,
+        model_path: str,
+        device: Optional[str] = None,
+        lora_model_path: Optional[str] = None,
+        lora_load_kwargs: Optional[Dict] = None,
+        lora_fuse_kwargs: Optional[Dict] = None,
+        **kwargs,
     ):
         self._model_uid = model_uid
         self._model_path = model_path
         self._device = device
         self._model = None
+        self._lora_model_path = lora_model_path
+        self._lora_load_kwargs = lora_load_kwargs or {}
+        self._lora_fuse_kwargs = lora_fuse_kwargs or {}
         self._kwargs = kwargs
 
+    def _apply_lora(self):
+        if self._lora_model_path is not None:
+            logger.info(
+                f"Loading the LoRA with load kwargs: {self._lora_load_kwargs}, fuse kwargs: {self._lora_fuse_kwargs}."
+            )
+            assert self._model is not None
+            self._model.load_lora_weights(
+                self._lora_model_path, **self._lora_load_kwargs
+            )
+            self._model.fuse_lora(**self._lora_fuse_kwargs)
+            logger.info(f"Successfully loaded the LoRA for model {self._model_uid}.")
+
     def load(self):
         # import torch
         from diffusers import AutoPipelineForText2Image
@@ -61,6 +83,7 @@
             self._model = move_model_to_available_device(self._model)
         # Recommended if your computer has < 64 GB of RAM
         self._model.enable_attention_slicing()
+        self._apply_lora()
 
     def _call_model(
         self,
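
End to end, peft_model_path and the image_lora_*_kwargs flow from the client through the supervisor and worker into DiffusionModel._apply_lora, which calls the diffusers pipeline's load_lora_weights and fuse_lora after the base pipeline is loaded. A sketch of launching an image model with a LoRA adapter through the Python client (endpoint, model name, adapter path, and kwarg values are illustrative, not prescribed by the diff):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")  # placeholder endpoint
    model_uid = client.launch_model(
        model_name="stable-diffusion-xl-base-1.0",
        model_type="image",
        peft_model_path="/path/to/lora_adapter",
        image_lora_load_kwargs={"weight_name": "pytorch_lora_weights.safetensors"},
        image_lora_fuse_kwargs={"lora_scale": 0.6},
    )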