xinference 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (55)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +9 -9
  3. xinference/client/restful/restful_client.py +32 -16
  4. xinference/core/supervisor.py +32 -9
  5. xinference/core/worker.py +13 -8
  6. xinference/deploy/cmdline.py +22 -9
  7. xinference/model/audio/__init__.py +40 -1
  8. xinference/model/audio/core.py +25 -45
  9. xinference/model/audio/custom.py +148 -0
  10. xinference/model/core.py +6 -9
  11. xinference/model/embedding/model_spec.json +24 -0
  12. xinference/model/embedding/model_spec_modelscope.json +24 -0
  13. xinference/model/image/core.py +12 -4
  14. xinference/model/image/stable_diffusion/core.py +8 -7
  15. xinference/model/llm/core.py +9 -14
  16. xinference/model/llm/llm_family.json +263 -0
  17. xinference/model/llm/llm_family.py +26 -4
  18. xinference/model/llm/llm_family_modelscope.json +160 -0
  19. xinference/model/llm/pytorch/baichuan.py +4 -3
  20. xinference/model/llm/pytorch/chatglm.py +3 -2
  21. xinference/model/llm/pytorch/core.py +15 -13
  22. xinference/model/llm/pytorch/falcon.py +6 -5
  23. xinference/model/llm/pytorch/internlm2.py +3 -2
  24. xinference/model/llm/pytorch/llama_2.py +6 -5
  25. xinference/model/llm/pytorch/vicuna.py +4 -3
  26. xinference/model/llm/vllm/core.py +3 -0
  27. xinference/model/rerank/core.py +23 -12
  28. xinference/model/rerank/model_spec.json +24 -0
  29. xinference/model/rerank/model_spec_modelscope.json +25 -1
  30. xinference/model/utils.py +12 -1
  31. xinference/types.py +55 -0
  32. xinference/utils.py +1 -0
  33. xinference/web/ui/build/asset-manifest.json +3 -3
  34. xinference/web/ui/build/index.html +1 -1
  35. xinference/web/ui/build/static/js/main.26fdbfbe.js +3 -0
  36. xinference/web/ui/build/static/js/main.26fdbfbe.js.map +1 -0
  37. xinference/web/ui/node_modules/.cache/babel-loader/1870cd6f7054d04e049e363c0a85526584fe25519378609d2838e28d7492bbf1.json +1 -0
  38. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/f4d5d1a41892a754c1ee0237450d804b20612d1b657945b59e564161ea47aa7a.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/f9290c0738db50065492ceedc6a4af25083fe18399b7c44d942273349ad9e643.json +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/fad4cd70de36ef6e6d5f8fd74a10ded58d964a8a91ef7681693fbb8376552da7.json +1 -0
  42. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +1 -0
  43. {xinference-0.10.1.dist-info → xinference-0.10.2.dist-info}/METADATA +4 -1
  44. {xinference-0.10.1.dist-info → xinference-0.10.2.dist-info}/RECORD +49 -46
  45. xinference/web/ui/build/static/js/main.76ef2b17.js +0 -3
  46. xinference/web/ui/build/static/js/main.76ef2b17.js.map +0 -1
  47. xinference/web/ui/node_modules/.cache/babel-loader/35d0e4a317e5582cbb79d901302e9d706520ac53f8a734c2fd8bfde6eb5a4f02.json +0 -1
  48. xinference/web/ui/node_modules/.cache/babel-loader/d076fd56cf3b15ed2433e3744b98c6b4e4410a19903d1db4de5bba0e1a1b3347.json +0 -1
  49. xinference/web/ui/node_modules/.cache/babel-loader/daad8131d91134f6d7aef895a0c9c32e1cb928277cb5aa66c01028126d215be0.json +0 -1
  50. xinference/web/ui/node_modules/.cache/babel-loader/f16aec63602a77bd561d0e67fa00b76469ac54b8033754bba114ec5eb3257964.json +0 -1
  51. /xinference/web/ui/build/static/js/{main.76ef2b17.js.LICENSE.txt → main.26fdbfbe.js.LICENSE.txt} +0 -0
  52. {xinference-0.10.1.dist-info → xinference-0.10.2.dist-info}/LICENSE +0 -0
  53. {xinference-0.10.1.dist-info → xinference-0.10.2.dist-info}/WHEEL +0 -0
  54. {xinference-0.10.1.dist-info → xinference-0.10.2.dist-info}/entry_points.txt +0 -0
  55. {xinference-0.10.1.dist-info → xinference-0.10.2.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-04-11T15:35:46+0800",
+ "date": "2024-04-19T11:39:12+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "e3a947ebddfc53b5e8ec723c1f632c2b895edef1",
- "version": "0.10.1"
+ "full-revisionid": "f19e85be09bce966e0c0b3e01bc5690eb6016398",
+ "version": "0.10.2"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -64,6 +64,7 @@ from ..types import (
     CreateChatCompletion,
     CreateCompletion,
     ImageList,
+    PeftModelConfig,
     max_tokens_field,
 )
 from .oauth2.auth_service import AuthService
@@ -692,9 +693,7 @@ class RESTfulAPI:
         replica = payload.get("replica", 1)
         n_gpu = payload.get("n_gpu", "auto")
         request_limits = payload.get("request_limits", None)
-        peft_model_path = payload.get("peft_model_path", None)
-        image_lora_load_kwargs = payload.get("image_lora_load_kwargs", None)
-        image_lora_fuse_kwargs = payload.get("image_lora_fuse_kwargs", None)
+        peft_model_config = payload.get("peft_model_config", None)
         worker_ip = payload.get("worker_ip", None)
         gpu_idx = payload.get("gpu_idx", None)
 
@@ -708,9 +707,7 @@
             "replica",
             "n_gpu",
             "request_limits",
-            "peft_model_path",
-            "image_lora_load_kwargs",
-            "image_lora_fuse_kwargs",
+            "peft_model_config",
             "worker_ip",
             "gpu_idx",
         }
@@ -725,6 +722,11 @@
                 detail="Invalid input. Please specify the model name",
             )
 
+        if peft_model_config is not None:
+            peft_model_config = PeftModelConfig.from_dict(peft_model_config)
+        else:
+            peft_model_config = None
+
         try:
             model_uid = await (await self._get_supervisor_ref()).launch_builtin_model(
                 model_uid=model_uid,
@@ -737,9 +739,7 @@
                 n_gpu=n_gpu,
                 request_limits=request_limits,
                 wait_ready=wait_ready,
-                peft_model_path=peft_model_path,
-                image_lora_load_kwargs=image_lora_load_kwargs,
-                image_lora_fuse_kwargs=image_lora_fuse_kwargs,
+                peft_model_config=peft_model_config,
                 worker_ip=worker_ip,
                 gpu_idx=gpu_idx,
                 **kwargs,
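
Net effect on the REST surface: the three flat LoRA-related fields are folded into one nested `peft_model_config` object, parsed server-side via `PeftModelConfig.from_dict`. A minimal sketch of the new request body (endpoint, model name, and adapter path are illustrative, not from the diff; the `lora_list` entry shape mirrors the dicts the CLI builds further down):

```python
import requests

payload = {
    "model_uid": None,
    "model_name": "llama-2-chat",  # illustrative model name
    "model_type": "LLM",
    "peft_model_config": {  # new in 0.10.2; replaces peft_model_path et al.
        "lora_list": [{"lora_name": "my_adapter", "local_path": "/path/to/adapter"}],
        "image_lora_load_kwargs": None,
        "image_lora_fuse_kwargs": None,
    },
}
requests.post("http://127.0.0.1:9997/v1/models", json=payload)
```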
xinference/client/restful/restful_client.py CHANGED
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
 
 import requests
 
+from ...model.utils import convert_float_to_int_or_str
+from ...types import LoRA, PeftModelConfig
 from ..common import streaming_response_iterator
 
 if TYPE_CHECKING:
@@ -746,7 +748,7 @@ class Client:
     def launch_speculative_llm(
         self,
         model_name: str,
-        model_size_in_billions: Optional[int],
+        model_size_in_billions: Optional[Union[int, str, float]],
         quantization: Optional[str],
         draft_model_name: str,
         draft_model_size_in_billions: Optional[int],
@@ -767,6 +769,10 @@
             "`launch_speculative_llm` is an experimental feature and the API may change in the future."
         )
 
+        # convert float to int or string since the RESTful API does not accept float.
+        if isinstance(model_size_in_billions, float):
+            model_size_in_billions = convert_float_to_int_or_str(model_size_in_billions)
+
         payload = {
             "model_uid": None,
             "model_name": model_name,
@@ -794,15 +800,13 @@
         model_name: str,
         model_type: str = "LLM",
         model_uid: Optional[str] = None,
-        model_size_in_billions: Optional[Union[int, str]] = None,
+        model_size_in_billions: Optional[Union[int, str, float]] = None,
         model_format: Optional[str] = None,
         quantization: Optional[str] = None,
         replica: int = 1,
         n_gpu: Optional[Union[int, str]] = "auto",
+        peft_model_config: Optional[Dict] = None,
         request_limits: Optional[int] = None,
-        peft_model_path: Optional[str] = None,
-        image_lora_load_kwargs: Optional[Dict] = None,
-        image_lora_fuse_kwargs: Optional[Dict] = None,
         worker_ip: Optional[str] = None,
         gpu_idx: Optional[Union[int, List[int]]] = None,
         **kwargs,
@@ -818,7 +822,7 @@
             type of model.
         model_uid: str
             UID of model, auto generate a UUID if is None.
-        model_size_in_billions: Optional[int]
+        model_size_in_billions: Optional[Union[int, str, float]]
             The size (in billions) of the model.
         model_format: Optional[str]
             The format of the model.
@@ -829,15 +833,13 @@
         n_gpu: Optional[Union[int, str]],
             The number of GPUs used by the model, default is "auto".
             ``n_gpu=None`` means cpu only, ``n_gpu=auto`` lets the system automatically determine the best number of GPUs to use.
+        peft_model_config: Optional[Dict]
+            - "lora_list": A List of PEFT (Parameter-Efficient Fine-Tuning) model and path.
+            - "image_lora_load_kwargs": A Dict of lora load parameters for image model
+            - "image_lora_fuse_kwargs": A Dict of lora fuse parameters for image model
         request_limits: Optional[int]
-            The number of request limits for this model default is None.
+            The number of request limits for this model, default is None.
             ``request_limits=None`` means no limits for this model.
-        peft_model_path: Optional[str]
-            PEFT (Parameter-Efficient Fine-Tuning) model path.
-        image_lora_load_kwargs: Optional[Dict]
-            lora load parameters for image model
-        image_lora_fuse_kwargs: Optional[Dict]
-            lora fuse parameters for image model
         worker_ip: Optional[str]
             Specify the worker ip where the model is located in a distributed scenario.
         gpu_idx: Optional[Union[int, List[int]]]
@@ -854,9 +856,26 @@
 
         url = f"{self.base_url}/v1/models"
 
+        if peft_model_config is not None:
+            lora_list = [
+                LoRA.from_dict(model) for model in peft_model_config["lora_list"]
+            ]
+            peft_model = PeftModelConfig(
+                lora_list,
+                peft_model_config["image_lora_load_kwargs"],
+                peft_model_config["image_lora_fuse_kwargs"],
+            )
+        else:
+            peft_model = None
+
+        # convert float to int or string since the RESTful API does not accept float.
+        if isinstance(model_size_in_billions, float):
+            model_size_in_billions = convert_float_to_int_or_str(model_size_in_billions)
+
         payload = {
             "model_uid": model_uid,
             "model_name": model_name,
+            "peft_model_config": peft_model.to_dict() if peft_model else None,
             "model_type": model_type,
             "model_size_in_billions": model_size_in_billions,
             "model_format": model_format,
@@ -864,9 +883,6 @@
             "replica": replica,
             "n_gpu": n_gpu,
             "request_limits": request_limits,
-            "peft_model_path": peft_model_path,
-            "image_lora_load_kwargs": image_lora_load_kwargs,
-            "image_lora_fuse_kwargs": image_lora_fuse_kwargs,
             "worker_ip": worker_ip,
             "gpu_idx": gpu_idx,
         }
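
From the caller's side, `launch_model` now takes the whole PEFT configuration as a single dict. A hedged usage sketch (model, quantization, and adapter values are illustrative). Note that the construction above indexes `lora_list`, `image_lora_load_kwargs`, and `image_lora_fuse_kwargs` unconditionally, so a caller passing `peft_model_config` must supply all three keys, even as None or an empty list.

```python
from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model_uid = client.launch_model(
    model_name="llama-2-chat",
    model_size_in_billions=7,
    quantization="none",
    peft_model_config={
        "lora_list": [{"lora_name": "my_adapter", "local_path": "/path/to/adapter"}],
        "image_lora_load_kwargs": None,  # only meaningful for image models
        "image_lora_fuse_kwargs": None,
    },
)
```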
xinference/core/supervisor.py CHANGED
@@ -30,6 +30,7 @@ from ..constants import (
 )
 from ..core import ModelActor
 from ..core.status_guard import InstanceInfo, LaunchStatus
+from ..types import PeftModelConfig
 from .metrics import record_metrics
 from .resource import GPUStatus, ResourceStatus
 from .utils import (
@@ -135,6 +136,13 @@ class SupervisorActor(xo.StatelessActor):
             EventCollectorActor, address=self.address, uid=EventCollectorActor.uid()
         )
 
+        from ..model.audio import (
+            CustomAudioModelFamilyV1,
+            generate_audio_description,
+            get_audio_model_descriptions,
+            register_audio,
+            unregister_audio,
+        )
         from ..model.embedding import (
             CustomEmbeddingModelSpec,
             generate_embedding_description,
@@ -177,6 +185,12 @@
                 unregister_rerank,
                 generate_rerank_description,
             ),
+            "audio": (
+                CustomAudioModelFamilyV1,
+                register_audio,
+                unregister_audio,
+                generate_audio_description,
+            ),
         }
 
         # record model version
@@ -185,6 +199,7 @@
         model_version_infos.update(get_embedding_model_descriptions())
         model_version_infos.update(get_rerank_model_descriptions())
         model_version_infos.update(get_image_model_descriptions())
+        model_version_infos.update(get_audio_model_descriptions())
         await self._cache_tracker_ref.record_model_version(
             model_version_infos, self.address
         )
@@ -483,6 +498,7 @@
             return ret
         elif model_type == "audio":
             from ..model.audio import BUILTIN_AUDIO_MODELS
+            from ..model.audio.custom import get_user_defined_audios
 
             ret = []
             for model_name, family in BUILTIN_AUDIO_MODELS.items():
@@ -491,6 +507,16 @@
                 else:
                     ret.append({"model_name": model_name, "is_builtin": True})
 
+            for model_spec in get_user_defined_audios():
+                if detailed:
+                    ret.append(
+                        await self._to_audio_model_reg(model_spec, is_builtin=False)
+                    )
+                else:
+                    ret.append(
+                        {"model_name": model_spec.model_name, "is_builtin": False}
+                    )
+
             ret.sort(key=sort_helper)
             return ret
         elif model_type == "rerank":
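
With this branch, user-registered audio models show up in registration listings alongside the builtins. A quick client-side check (assuming the pre-existing `list_model_registrations` client method, which is not part of this diff):

```python
from xinference.client import Client

client = Client("http://127.0.0.1:9997")
# Builtins report is_builtin=True; custom audio families report False.
for reg in client.list_model_registrations(model_type="audio"):
    print(reg["model_name"], reg["is_builtin"])
```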
@@ -548,8 +574,9 @@
             raise ValueError(f"Model {model_name} not found")
         elif model_type == "audio":
             from ..model.audio import BUILTIN_AUDIO_MODELS
+            from ..model.audio.custom import get_user_defined_audios
 
-            for f in BUILTIN_AUDIO_MODELS.values():
+            for f in list(BUILTIN_AUDIO_MODELS.values()) + get_user_defined_audios():
                 if f.model_name == model_name:
                     return f
             raise ValueError(f"Model {model_name} not found")
@@ -654,7 +681,7 @@
         self,
         model_uid: Optional[str],
         model_name: str,
-        model_size_in_billions: Optional[int],
+        model_size_in_billions: Optional[Union[int, str]],
         quantization: Optional[str],
         draft_model_name: str,
         draft_model_size_in_billions: Optional[int],
@@ -714,7 +741,7 @@
         self,
         model_uid: Optional[str],
         model_name: str,
-        model_size_in_billions: Optional[int],
+        model_size_in_billions: Optional[Union[int, str]],
         model_format: Optional[str],
         quantization: Optional[str],
         model_type: Optional[str],
@@ -723,9 +750,7 @@
         request_limits: Optional[int] = None,
         wait_ready: bool = True,
         model_version: Optional[str] = None,
-        peft_model_path: Optional[str] = None,
-        image_lora_load_kwargs: Optional[Dict] = None,
-        image_lora_fuse_kwargs: Optional[Dict] = None,
+        peft_model_config: Optional[PeftModelConfig] = None,
         worker_ip: Optional[str] = None,
         gpu_idx: Optional[Union[int, List[int]]] = None,
         **kwargs,
@@ -777,9 +802,7 @@
             model_type=model_type,
             n_gpu=n_gpu,
             request_limits=request_limits,
-            peft_model_path=peft_model_path,
-            image_lora_load_kwargs=image_lora_load_kwargs,
-            image_lora_fuse_kwargs=image_lora_fuse_kwargs,
+            peft_model_config=peft_model_config,
             gpu_idx=gpu_idx,
             **kwargs,
         )
xinference/core/worker.py CHANGED
@@ -36,6 +36,7 @@ from ..core import ModelActor
 from ..core.status_guard import LaunchStatus
 from ..device_utils import gpu_count
 from ..model.core import ModelDescription, create_model_instance
+from ..types import PeftModelConfig
 from .event import Event, EventCollectorActor, EventType
 from .metrics import launch_metrics_export_server, record_metrics
 from .resource import gather_node_info
@@ -195,6 +196,12 @@ class WorkerActor(xo.StatelessActor):
         logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
         purge_dir(XINFERENCE_CACHE_DIR)
 
+        from ..model.audio import (
+            CustomAudioModelFamilyV1,
+            get_audio_model_descriptions,
+            register_audio,
+            unregister_audio,
+        )
         from ..model.embedding import (
             CustomEmbeddingModelSpec,
             get_embedding_model_descriptions,
@@ -223,6 +230,7 @@
                 unregister_embedding,
             ),
             "rerank": (CustomRerankModelSpec, register_rerank, unregister_rerank),
+            "audio": (CustomAudioModelFamilyV1, register_audio, unregister_audio),
         }
 
         # record model version
@@ -231,6 +239,7 @@
         model_version_infos.update(get_embedding_model_descriptions())
         model_version_infos.update(get_rerank_model_descriptions())
         model_version_infos.update(get_image_model_descriptions())
+        model_version_infos.update(get_audio_model_descriptions())
         await self._cache_tracker_ref.record_model_version(
             model_version_infos, self.address
         )
@@ -593,14 +602,12 @@
         self,
         model_uid: str,
         model_name: str,
-        model_size_in_billions: Optional[int],
+        model_size_in_billions: Optional[Union[int, str]],
         model_format: Optional[str],
        quantization: Optional[str],
         model_type: str = "LLM",
         n_gpu: Optional[Union[int, str]] = "auto",
-        peft_model_path: Optional[str] = None,
-        image_lora_load_kwargs: Optional[Dict] = None,
-        image_lora_fuse_kwargs: Optional[Dict] = None,
+        peft_model_config: Optional[PeftModelConfig] = None,
         request_limits: Optional[int] = None,
         gpu_idx: Optional[Union[int, List[int]]] = None,
         **kwargs,
@@ -638,7 +645,7 @@
         if isinstance(n_gpu, str) and n_gpu != "auto":
             raise ValueError("Currently `n_gpu` only supports `auto`.")
 
-        if peft_model_path is not None:
+        if peft_model_config is not None:
             if model_type in ("embedding", "rerank"):
                 raise ValueError(
                     f"PEFT adaptors cannot be applied to embedding or rerank models."
@@ -669,9 +676,7 @@
             model_format,
             model_size_in_billions,
             quantization,
-            peft_model_path,
-            image_lora_load_kwargs,
-            image_lora_fuse_kwargs,
+            peft_model_config,
             is_local_deployment,
             **kwargs,
         )
xinference/deploy/cmdline.py CHANGED
@@ -640,10 +640,11 @@ def list_model_registrations(
     help='The number of GPUs used by the model, default is "auto".',
 )
 @click.option(
-    "--peft-model-path",
-    default=None,
-    type=str,
-    help="PEFT model path.",
+    "--lora-modules",
+    "-lm",
+    multiple=True,
+    type=(str, str),
+    help="LoRA module configurations in the format name=path. Multiple modules can be specified.",
 )
 @click.option(
     "--image-lora-load-kwargs",
@@ -696,7 +697,7 @@ def model_launch(
     quantization: str,
     replica: int,
     n_gpu: str,
-    peft_model_path: Optional[str],
+    lora_modules: Optional[Tuple],
     image_lora_load_kwargs: Optional[Tuple],
     image_lora_fuse_kwargs: Optional[Tuple],
     worker_ip: Optional[str],
@@ -729,6 +730,18 @@
         else None
     )
 
+    lora_list = (
+        [{"lora_name": k, "local_path": v} for k, v in dict(lora_modules).items()]
+        if lora_modules
+        else []
+    )
+
+    peft_model_config = {
+        "image_lora_load_kwargs": image_lora_load_params,
+        "image_lora_fuse_kwargs": image_lora_fuse_params,
+        "lora_list": lora_list,
+    }
+
     _gpu_idx: Optional[List[int]] = (
         None if gpu_idx is None else [int(idx) for idx in gpu_idx.split(",")]
     )
@@ -736,7 +749,9 @@
     endpoint = get_endpoint(endpoint)
     model_size: Optional[Union[str, int]] = (
         size_in_billions
-        if size_in_billions is None or "_" in size_in_billions
+        if size_in_billions is None
+        or "_" in size_in_billions
+        or "." in size_in_billions
         else int(size_in_billions)
     )
     client = RESTfulClient(base_url=endpoint, api_key=api_key)
@@ -752,9 +767,7 @@
         quantization=quantization,
         replica=replica,
         n_gpu=_n_gpu,
-        peft_model_path=peft_model_path,
-        image_lora_load_kwargs=image_lora_load_params,
-        image_lora_fuse_kwargs=image_lora_fuse_params,
+        peft_model_config=peft_model_config,
         worker_ip=worker_ip,
         gpu_idx=_gpu_idx,
         trust_remote_code=trust_remote_code,
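
Putting the CLI changes together: `--peft-model-path` is gone, and each `--lora-modules`/`-lm` occurrence consumes a (name, path) pair, per click's `(str, str)` tuple type, despite the name=path wording in the help text. A sketch of what `model_launch` builds from a hypothetical invocation (adapter names and paths are illustrative):

```python
# Given: xinference launch ... --lora-modules my_adapter /path/to/adapter_a \
#            -lm second_adapter /path/to/adapter_b
# click hands model_launch a tuple of pairs:
lora_modules = (("my_adapter", "/path/to/adapter_a"), ("second_adapter", "/path/to/adapter_b"))

# Per the hunk above, this becomes the lora_list of dicts...
lora_list = [{"lora_name": k, "local_path": v} for k, v in dict(lora_modules).items()]

# ...wrapped into the peft_model_config dict passed to client.launch_model.
peft_model_config = {
    "image_lora_load_kwargs": None,
    "image_lora_fuse_kwargs": None,
    "lora_list": lora_list,
}
```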
xinference/model/audio/__init__.py CHANGED
@@ -16,12 +16,51 @@ import codecs
 import json
 import os
 
-from .core import AudioModelFamilyV1, generate_audio_description, get_cache_status
+from .core import (
+    AUDIO_MODEL_DESCRIPTIONS,
+    MODEL_NAME_TO_REVISION,
+    AudioModelFamilyV1,
+    generate_audio_description,
+    get_audio_model_descriptions,
+    get_cache_status,
+)
+from .custom import (
+    CustomAudioModelFamilyV1,
+    get_user_defined_audios,
+    register_audio,
+    unregister_audio,
+)
 
 _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
 BUILTIN_AUDIO_MODELS = dict(
     (spec["model_name"], AudioModelFamilyV1(**spec))
     for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
 )
+for model_name, model_spec in BUILTIN_AUDIO_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
+# register model description after recording model revision
+for model_spec_info in [BUILTIN_AUDIO_MODELS]:
+    for model_name, model_spec in model_spec_info.items():
+        if model_spec.model_name not in AUDIO_MODEL_DESCRIPTIONS:
+            AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(model_spec))
+
+from ...constants import XINFERENCE_MODEL_DIR
+
+# if persist=True, load them when init
+user_defined_audio_dir = os.path.join(XINFERENCE_MODEL_DIR, "audio")
+if os.path.isdir(user_defined_audio_dir):
+    for f in os.listdir(user_defined_audio_dir):
+        with codecs.open(
+            os.path.join(user_defined_audio_dir, f), encoding="utf-8"
+        ) as fd:
+            user_defined_audio_family = CustomAudioModelFamilyV1.parse_obj(
+                json.load(fd)
+            )
+            register_audio(user_defined_audio_family, persist=False)
+
+# register model description
+for ud_audio in get_user_defined_audios():
+    AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(ud_audio))
 
 del _model_spec_json
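
The loader above gives custom audio specs the same persistence path as other model types: JSON files under XINFERENCE_MODEL_DIR/audio, re-registered at startup. A hedged registration sketch, assuming the pre-existing `register_model` client call, which the supervisor's register-type table now routes for "audio" (field values are illustrative; the authoritative field set lives in the new custom.py, which this view does not expand):

```python
import json
from xinference.client import Client

# Illustrative spec; AudioModelFamilyV1 declares model_family,
# model_name, model_id, and a model_revision per the core.py hunks below.
spec = {
    "model_family": "whisper",
    "model_name": "my-whisper-finetune",
    "model_id": "my-org/my-whisper-finetune",
    "model_revision": "main",
}

client = Client("http://127.0.0.1:9997")
client.register_model(model_type="audio", model=json.dumps(spec), persist=True)
```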
xinference/model/audio/core.py CHANGED
@@ -16,9 +16,8 @@ import os
 from collections import defaultdict
 from typing import Dict, List, Optional, Tuple
 
-from ..._compat import BaseModel
 from ...constants import XINFERENCE_CACHE_DIR
-from ..core import ModelDescription
+from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .whisper import WhisperModel
 
@@ -26,8 +25,19 @@ MAX_ATTEMPTS = 3
 
 logger = logging.getLogger(__name__)
 
+# Used for check whether the model is cached.
+# Init when registering all the builtin models.
+MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
+AUDIO_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
 
-class AudioModelFamilyV1(BaseModel):
+
+def get_audio_model_descriptions():
+    import copy
+
+    return copy.deepcopy(AUDIO_MODEL_DESCRIPTIONS)
+
+
+class AudioModelFamilyV1(CacheableModelSpec):
     model_family: str
     model_name: str
     model_id: str
@@ -77,63 +87,33 @@ def generate_audio_description(
     image_model: AudioModelFamilyV1,
 ) -> Dict[str, List[Dict]]:
     res = defaultdict(list)
-    res[image_model.model_name].extend(
-        AudioModelDescription(None, None, image_model).to_dict()
+    res[image_model.model_name].append(
+        AudioModelDescription(None, None, image_model).to_version_info()
     )
     return res
 
 
-def match_model(model_name: str) -> AudioModelFamilyV1:
+def match_audio(model_name: str) -> AudioModelFamilyV1:
     from . import BUILTIN_AUDIO_MODELS
+    from .custom import get_user_defined_audios
+
+    for model_spec in get_user_defined_audios():
+        if model_spec.model_name == model_name:
+            return model_spec
 
     if model_name in BUILTIN_AUDIO_MODELS:
         return BUILTIN_AUDIO_MODELS[model_name]
     else:
         raise ValueError(
-            f"Image model {model_name} not found, available"
+            f"Audio model {model_name} not found, available"
             f"model list: {BUILTIN_AUDIO_MODELS.keys()}"
         )
 
 
 def cache(model_spec: AudioModelFamilyV1):
-    # TODO: cache from uri
-    import huggingface_hub
-
-    cache_dir = get_cache_dir(model_spec)
-    if not os.path.exists(cache_dir):
-        os.makedirs(cache_dir, exist_ok=True)
-
-    meta_path = os.path.join(cache_dir, "__valid_download")
-    if valid_model_revision(meta_path, model_spec.model_revision):
-        return cache_dir
-
-    for current_attempt in range(1, MAX_ATTEMPTS + 1):
-        try:
-            huggingface_hub.snapshot_download(
-                model_spec.model_id,
-                revision=model_spec.model_revision,
-                local_dir=cache_dir,
-                local_dir_use_symlinks=True,
-                resume_download=True,
-            )
-            break
-        except huggingface_hub.utils.LocalEntryNotFoundError:
-            remaining_attempts = MAX_ATTEMPTS - current_attempt
-            logger.warning(
-                f"Attempt {current_attempt} failed. Remaining attempts: {remaining_attempts}"
-            )
-    else:
-        raise RuntimeError(
-            f"Failed to download model '{model_spec.model_name}' after {MAX_ATTEMPTS} attempts"
-        )
-
-    with open(meta_path, "w") as f:
-        import json
-
-        desc = AudioModelDescription(None, None, model_spec)
-        json.dump(desc.to_dict(), f)
+    from ..utils import cache
 
-    return cache_dir
+    return cache(model_spec, AudioModelDescription)
 
 
 def get_cache_dir(model_spec: AudioModelFamilyV1):
@@ -151,7 +131,7 @@ def get_cache_status(
 def create_audio_model_instance(
     subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
 ) -> Tuple[WhisperModel, AudioModelDescription]:
-    model_spec = match_model(model_name)
+    model_spec = match_audio(model_name)
     model_path = cache(model_spec)
     model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
     model_description = AudioModelDescription(