xinference 0.10.1__py3-none-any.whl → 0.10.2.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.

This is a potentially problematic release.


This version of xinference might be problematic; see the registry's advisory page for more details.

Files changed (55) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +9 -9
  3. xinference/client/restful/restful_client.py +29 -16
  4. xinference/core/supervisor.py +32 -9
  5. xinference/core/worker.py +13 -8
  6. xinference/deploy/cmdline.py +22 -9
  7. xinference/model/audio/__init__.py +40 -1
  8. xinference/model/audio/core.py +25 -45
  9. xinference/model/audio/custom.py +148 -0
  10. xinference/model/core.py +6 -9
  11. xinference/model/embedding/model_spec.json +24 -0
  12. xinference/model/embedding/model_spec_modelscope.json +24 -0
  13. xinference/model/image/core.py +12 -4
  14. xinference/model/image/stable_diffusion/core.py +8 -7
  15. xinference/model/llm/core.py +9 -14
  16. xinference/model/llm/llm_family.json +263 -0
  17. xinference/model/llm/llm_family.py +26 -4
  18. xinference/model/llm/llm_family_modelscope.json +160 -0
  19. xinference/model/llm/pytorch/baichuan.py +4 -3
  20. xinference/model/llm/pytorch/chatglm.py +3 -2
  21. xinference/model/llm/pytorch/core.py +15 -13
  22. xinference/model/llm/pytorch/falcon.py +6 -5
  23. xinference/model/llm/pytorch/internlm2.py +3 -2
  24. xinference/model/llm/pytorch/llama_2.py +6 -5
  25. xinference/model/llm/pytorch/vicuna.py +4 -3
  26. xinference/model/llm/vllm/core.py +3 -0
  27. xinference/model/rerank/core.py +23 -12
  28. xinference/model/rerank/model_spec.json +24 -0
  29. xinference/model/rerank/model_spec_modelscope.json +25 -1
  30. xinference/model/utils.py +12 -1
  31. xinference/types.py +55 -0
  32. xinference/utils.py +1 -0
  33. xinference/web/ui/build/asset-manifest.json +3 -3
  34. xinference/web/ui/build/index.html +1 -1
  35. xinference/web/ui/build/static/js/main.26fdbfbe.js +3 -0
  36. xinference/web/ui/build/static/js/main.26fdbfbe.js.map +1 -0
  37. xinference/web/ui/node_modules/.cache/babel-loader/1870cd6f7054d04e049e363c0a85526584fe25519378609d2838e28d7492bbf1.json +1 -0
  38. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/f4d5d1a41892a754c1ee0237450d804b20612d1b657945b59e564161ea47aa7a.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/f9290c0738db50065492ceedc6a4af25083fe18399b7c44d942273349ad9e643.json +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/fad4cd70de36ef6e6d5f8fd74a10ded58d964a8a91ef7681693fbb8376552da7.json +1 -0
  42. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +1 -0
  43. {xinference-0.10.1.dist-info → xinference-0.10.2.post1.dist-info}/METADATA +4 -1
  44. {xinference-0.10.1.dist-info → xinference-0.10.2.post1.dist-info}/RECORD +49 -46
  45. xinference/web/ui/build/static/js/main.76ef2b17.js +0 -3
  46. xinference/web/ui/build/static/js/main.76ef2b17.js.map +0 -1
  47. xinference/web/ui/node_modules/.cache/babel-loader/35d0e4a317e5582cbb79d901302e9d706520ac53f8a734c2fd8bfde6eb5a4f02.json +0 -1
  48. xinference/web/ui/node_modules/.cache/babel-loader/d076fd56cf3b15ed2433e3744b98c6b4e4410a19903d1db4de5bba0e1a1b3347.json +0 -1
  49. xinference/web/ui/node_modules/.cache/babel-loader/daad8131d91134f6d7aef895a0c9c32e1cb928277cb5aa66c01028126d215be0.json +0 -1
  50. xinference/web/ui/node_modules/.cache/babel-loader/f16aec63602a77bd561d0e67fa00b76469ac54b8033754bba114ec5eb3257964.json +0 -1
  51. /xinference/web/ui/build/static/js/{main.76ef2b17.js.LICENSE.txt → main.26fdbfbe.js.LICENSE.txt} +0 -0
  52. {xinference-0.10.1.dist-info → xinference-0.10.2.post1.dist-info}/LICENSE +0 -0
  53. {xinference-0.10.1.dist-info → xinference-0.10.2.post1.dist-info}/WHEEL +0 -0
  54. {xinference-0.10.1.dist-info → xinference-0.10.2.post1.dist-info}/entry_points.txt +0 -0
  55. {xinference-0.10.1.dist-info → xinference-0.10.2.post1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-04-11T15:35:46+0800",
11
+ "date": "2024-04-19T14:40:59+0800",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "e3a947ebddfc53b5e8ec723c1f632c2b895edef1",
15
- "version": "0.10.1"
14
+ "full-revisionid": "500171569de25d49f6ddb3c167d9fc0e55cd66c7",
15
+ "version": "0.10.2.post1"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -64,6 +64,7 @@ from ..types import (
64
64
  CreateChatCompletion,
65
65
  CreateCompletion,
66
66
  ImageList,
67
+ PeftModelConfig,
67
68
  max_tokens_field,
68
69
  )
69
70
  from .oauth2.auth_service import AuthService
@@ -692,9 +693,7 @@ class RESTfulAPI:
692
693
  replica = payload.get("replica", 1)
693
694
  n_gpu = payload.get("n_gpu", "auto")
694
695
  request_limits = payload.get("request_limits", None)
695
- peft_model_path = payload.get("peft_model_path", None)
696
- image_lora_load_kwargs = payload.get("image_lora_load_kwargs", None)
697
- image_lora_fuse_kwargs = payload.get("image_lora_fuse_kwargs", None)
696
+ peft_model_config = payload.get("peft_model_config", None)
698
697
  worker_ip = payload.get("worker_ip", None)
699
698
  gpu_idx = payload.get("gpu_idx", None)
700
699
 
@@ -708,9 +707,7 @@ class RESTfulAPI:
708
707
  "replica",
709
708
  "n_gpu",
710
709
  "request_limits",
711
- "peft_model_path",
712
- "image_lora_load_kwargs",
713
- "image_lora_fuse_kwargs",
710
+ "peft_model_config",
714
711
  "worker_ip",
715
712
  "gpu_idx",
716
713
  }
@@ -725,6 +722,11 @@ class RESTfulAPI:
725
722
  detail="Invalid input. Please specify the model name",
726
723
  )
727
724
 
725
+ if peft_model_config is not None:
726
+ peft_model_config = PeftModelConfig.from_dict(peft_model_config)
727
+ else:
728
+ peft_model_config = None
729
+
728
730
  try:
729
731
  model_uid = await (await self._get_supervisor_ref()).launch_builtin_model(
730
732
  model_uid=model_uid,
@@ -737,9 +739,7 @@ class RESTfulAPI:
737
739
  n_gpu=n_gpu,
738
740
  request_limits=request_limits,
739
741
  wait_ready=wait_ready,
740
- peft_model_path=peft_model_path,
741
- image_lora_load_kwargs=image_lora_load_kwargs,
742
- image_lora_fuse_kwargs=image_lora_fuse_kwargs,
742
+ peft_model_config=peft_model_config,
743
743
  worker_ip=worker_ip,
744
744
  gpu_idx=gpu_idx,
745
745
  **kwargs,
@@ -35,6 +35,17 @@ if TYPE_CHECKING:
35
35
  )
36
36
 
37
37
 
38
+ def convert_float_to_int_or_str(model_size: float) -> Union[int, str]:
39
+ """convert float to int or string
40
+
41
+ if float can be presented as int, convert it to int, otherwise convert it to string
42
+ """
43
+ if int(model_size) == model_size:
44
+ return int(model_size)
45
+ else:
46
+ return str(model_size)
47
+
48
+
38
49
  def _get_error_string(response: requests.Response) -> str:
39
50
  try:
40
51
  if response.content:
@@ -746,7 +757,7 @@ class Client:
746
757
  def launch_speculative_llm(
747
758
  self,
748
759
  model_name: str,
749
- model_size_in_billions: Optional[int],
760
+ model_size_in_billions: Optional[Union[int, str, float]],
750
761
  quantization: Optional[str],
751
762
  draft_model_name: str,
752
763
  draft_model_size_in_billions: Optional[int],
@@ -767,6 +778,10 @@ class Client:
767
778
  "`launch_speculative_llm` is an experimental feature and the API may change in the future."
768
779
  )
769
780
 
781
+ # convert float to int or string since the RESTful API does not accept float.
782
+ if isinstance(model_size_in_billions, float):
783
+ model_size_in_billions = convert_float_to_int_or_str(model_size_in_billions)
784
+
770
785
  payload = {
771
786
  "model_uid": None,
772
787
  "model_name": model_name,
@@ -794,15 +809,13 @@ class Client:
794
809
  model_name: str,
795
810
  model_type: str = "LLM",
796
811
  model_uid: Optional[str] = None,
797
- model_size_in_billions: Optional[Union[int, str]] = None,
812
+ model_size_in_billions: Optional[Union[int, str, float]] = None,
798
813
  model_format: Optional[str] = None,
799
814
  quantization: Optional[str] = None,
800
815
  replica: int = 1,
801
816
  n_gpu: Optional[Union[int, str]] = "auto",
817
+ peft_model_config: Optional[Dict] = None,
802
818
  request_limits: Optional[int] = None,
803
- peft_model_path: Optional[str] = None,
804
- image_lora_load_kwargs: Optional[Dict] = None,
805
- image_lora_fuse_kwargs: Optional[Dict] = None,
806
819
  worker_ip: Optional[str] = None,
807
820
  gpu_idx: Optional[Union[int, List[int]]] = None,
808
821
  **kwargs,
@@ -818,7 +831,7 @@ class Client:
818
831
  type of model.
819
832
  model_uid: str
820
833
  UID of model, auto generate a UUID if is None.
821
- model_size_in_billions: Optional[int]
834
+ model_size_in_billions: Optional[Union[int, str, float]]
822
835
  The size (in billions) of the model.
823
836
  model_format: Optional[str]
824
837
  The format of the model.
@@ -829,15 +842,13 @@ class Client:
829
842
  n_gpu: Optional[Union[int, str]],
830
843
  The number of GPUs used by the model, default is "auto".
831
844
  ``n_gpu=None`` means cpu only, ``n_gpu=auto`` lets the system automatically determine the best number of GPUs to use.
845
+ peft_model_config: Optional[Dict]
846
+ - "lora_list": A List of PEFT (Parameter-Efficient Fine-Tuning) model and path.
847
+ - "image_lora_load_kwargs": A Dict of lora load parameters for image model
848
+ - "image_lora_fuse_kwargs": A Dict of lora fuse parameters for image model
832
849
  request_limits: Optional[int]
833
- The number of request limits for this model default is None.
850
+ The number of request limits for this model, default is None.
834
851
  ``request_limits=None`` means no limits for this model.
835
- peft_model_path: Optional[str]
836
- PEFT (Parameter-Efficient Fine-Tuning) model path.
837
- image_lora_load_kwargs: Optional[Dict]
838
- lora load parameters for image model
839
- image_lora_fuse_kwargs: Optional[Dict]
840
- lora fuse parameters for image model
841
852
  worker_ip: Optional[str]
842
853
  Specify the worker ip where the model is located in a distributed scenario.
843
854
  gpu_idx: Optional[Union[int, List[int]]]
@@ -854,9 +865,14 @@ class Client:
854
865
 
855
866
  url = f"{self.base_url}/v1/models"
856
867
 
868
+ # convert float to int or string since the RESTful API does not accept float.
869
+ if isinstance(model_size_in_billions, float):
870
+ model_size_in_billions = convert_float_to_int_or_str(model_size_in_billions)
871
+
857
872
  payload = {
858
873
  "model_uid": model_uid,
859
874
  "model_name": model_name,
875
+ "peft_model_config": peft_model_config,
860
876
  "model_type": model_type,
861
877
  "model_size_in_billions": model_size_in_billions,
862
878
  "model_format": model_format,
@@ -864,9 +880,6 @@ class Client:
864
880
  "replica": replica,
865
881
  "n_gpu": n_gpu,
866
882
  "request_limits": request_limits,
867
- "peft_model_path": peft_model_path,
868
- "image_lora_load_kwargs": image_lora_load_kwargs,
869
- "image_lora_fuse_kwargs": image_lora_fuse_kwargs,
870
883
  "worker_ip": worker_ip,
871
884
  "gpu_idx": gpu_idx,
872
885
  }
@@ -30,6 +30,7 @@ from ..constants import (
30
30
  )
31
31
  from ..core import ModelActor
32
32
  from ..core.status_guard import InstanceInfo, LaunchStatus
33
+ from ..types import PeftModelConfig
33
34
  from .metrics import record_metrics
34
35
  from .resource import GPUStatus, ResourceStatus
35
36
  from .utils import (
@@ -135,6 +136,13 @@ class SupervisorActor(xo.StatelessActor):
135
136
  EventCollectorActor, address=self.address, uid=EventCollectorActor.uid()
136
137
  )
137
138
 
139
+ from ..model.audio import (
140
+ CustomAudioModelFamilyV1,
141
+ generate_audio_description,
142
+ get_audio_model_descriptions,
143
+ register_audio,
144
+ unregister_audio,
145
+ )
138
146
  from ..model.embedding import (
139
147
  CustomEmbeddingModelSpec,
140
148
  generate_embedding_description,
@@ -177,6 +185,12 @@ class SupervisorActor(xo.StatelessActor):
177
185
  unregister_rerank,
178
186
  generate_rerank_description,
179
187
  ),
188
+ "audio": (
189
+ CustomAudioModelFamilyV1,
190
+ register_audio,
191
+ unregister_audio,
192
+ generate_audio_description,
193
+ ),
180
194
  }
181
195
 
182
196
  # record model version
@@ -185,6 +199,7 @@ class SupervisorActor(xo.StatelessActor):
185
199
  model_version_infos.update(get_embedding_model_descriptions())
186
200
  model_version_infos.update(get_rerank_model_descriptions())
187
201
  model_version_infos.update(get_image_model_descriptions())
202
+ model_version_infos.update(get_audio_model_descriptions())
188
203
  await self._cache_tracker_ref.record_model_version(
189
204
  model_version_infos, self.address
190
205
  )
@@ -483,6 +498,7 @@ class SupervisorActor(xo.StatelessActor):
483
498
  return ret
484
499
  elif model_type == "audio":
485
500
  from ..model.audio import BUILTIN_AUDIO_MODELS
501
+ from ..model.audio.custom import get_user_defined_audios
486
502
 
487
503
  ret = []
488
504
  for model_name, family in BUILTIN_AUDIO_MODELS.items():
@@ -491,6 +507,16 @@ class SupervisorActor(xo.StatelessActor):
491
507
  else:
492
508
  ret.append({"model_name": model_name, "is_builtin": True})
493
509
 
510
+ for model_spec in get_user_defined_audios():
511
+ if detailed:
512
+ ret.append(
513
+ await self._to_audio_model_reg(model_spec, is_builtin=False)
514
+ )
515
+ else:
516
+ ret.append(
517
+ {"model_name": model_spec.model_name, "is_builtin": False}
518
+ )
519
+
494
520
  ret.sort(key=sort_helper)
495
521
  return ret
496
522
  elif model_type == "rerank":
@@ -548,8 +574,9 @@ class SupervisorActor(xo.StatelessActor):
548
574
  raise ValueError(f"Model {model_name} not found")
549
575
  elif model_type == "audio":
550
576
  from ..model.audio import BUILTIN_AUDIO_MODELS
577
+ from ..model.audio.custom import get_user_defined_audios
551
578
 
552
- for f in BUILTIN_AUDIO_MODELS.values():
579
+ for f in list(BUILTIN_AUDIO_MODELS.values()) + get_user_defined_audios():
553
580
  if f.model_name == model_name:
554
581
  return f
555
582
  raise ValueError(f"Model {model_name} not found")
@@ -654,7 +681,7 @@ class SupervisorActor(xo.StatelessActor):
654
681
  self,
655
682
  model_uid: Optional[str],
656
683
  model_name: str,
657
- model_size_in_billions: Optional[int],
684
+ model_size_in_billions: Optional[Union[int, str]],
658
685
  quantization: Optional[str],
659
686
  draft_model_name: str,
660
687
  draft_model_size_in_billions: Optional[int],
@@ -714,7 +741,7 @@ class SupervisorActor(xo.StatelessActor):
714
741
  self,
715
742
  model_uid: Optional[str],
716
743
  model_name: str,
717
- model_size_in_billions: Optional[int],
744
+ model_size_in_billions: Optional[Union[int, str]],
718
745
  model_format: Optional[str],
719
746
  quantization: Optional[str],
720
747
  model_type: Optional[str],
@@ -723,9 +750,7 @@ class SupervisorActor(xo.StatelessActor):
723
750
  request_limits: Optional[int] = None,
724
751
  wait_ready: bool = True,
725
752
  model_version: Optional[str] = None,
726
- peft_model_path: Optional[str] = None,
727
- image_lora_load_kwargs: Optional[Dict] = None,
728
- image_lora_fuse_kwargs: Optional[Dict] = None,
753
+ peft_model_config: Optional[PeftModelConfig] = None,
729
754
  worker_ip: Optional[str] = None,
730
755
  gpu_idx: Optional[Union[int, List[int]]] = None,
731
756
  **kwargs,
@@ -777,9 +802,7 @@ class SupervisorActor(xo.StatelessActor):
777
802
  model_type=model_type,
778
803
  n_gpu=n_gpu,
779
804
  request_limits=request_limits,
780
- peft_model_path=peft_model_path,
781
- image_lora_load_kwargs=image_lora_load_kwargs,
782
- image_lora_fuse_kwargs=image_lora_fuse_kwargs,
805
+ peft_model_config=peft_model_config,
783
806
  gpu_idx=gpu_idx,
784
807
  **kwargs,
785
808
  )
xinference/core/worker.py CHANGED
@@ -36,6 +36,7 @@ from ..core import ModelActor
36
36
  from ..core.status_guard import LaunchStatus
37
37
  from ..device_utils import gpu_count
38
38
  from ..model.core import ModelDescription, create_model_instance
39
+ from ..types import PeftModelConfig
39
40
  from .event import Event, EventCollectorActor, EventType
40
41
  from .metrics import launch_metrics_export_server, record_metrics
41
42
  from .resource import gather_node_info
@@ -195,6 +196,12 @@ class WorkerActor(xo.StatelessActor):
195
196
  logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
196
197
  purge_dir(XINFERENCE_CACHE_DIR)
197
198
 
199
+ from ..model.audio import (
200
+ CustomAudioModelFamilyV1,
201
+ get_audio_model_descriptions,
202
+ register_audio,
203
+ unregister_audio,
204
+ )
198
205
  from ..model.embedding import (
199
206
  CustomEmbeddingModelSpec,
200
207
  get_embedding_model_descriptions,
@@ -223,6 +230,7 @@ class WorkerActor(xo.StatelessActor):
223
230
  unregister_embedding,
224
231
  ),
225
232
  "rerank": (CustomRerankModelSpec, register_rerank, unregister_rerank),
233
+ "audio": (CustomAudioModelFamilyV1, register_audio, unregister_audio),
226
234
  }
227
235
 
228
236
  # record model version
@@ -231,6 +239,7 @@ class WorkerActor(xo.StatelessActor):
231
239
  model_version_infos.update(get_embedding_model_descriptions())
232
240
  model_version_infos.update(get_rerank_model_descriptions())
233
241
  model_version_infos.update(get_image_model_descriptions())
242
+ model_version_infos.update(get_audio_model_descriptions())
234
243
  await self._cache_tracker_ref.record_model_version(
235
244
  model_version_infos, self.address
236
245
  )
@@ -593,14 +602,12 @@ class WorkerActor(xo.StatelessActor):
593
602
  self,
594
603
  model_uid: str,
595
604
  model_name: str,
596
- model_size_in_billions: Optional[int],
605
+ model_size_in_billions: Optional[Union[int, str]],
597
606
  model_format: Optional[str],
598
607
  quantization: Optional[str],
599
608
  model_type: str = "LLM",
600
609
  n_gpu: Optional[Union[int, str]] = "auto",
601
- peft_model_path: Optional[str] = None,
602
- image_lora_load_kwargs: Optional[Dict] = None,
603
- image_lora_fuse_kwargs: Optional[Dict] = None,
610
+ peft_model_config: Optional[PeftModelConfig] = None,
604
611
  request_limits: Optional[int] = None,
605
612
  gpu_idx: Optional[Union[int, List[int]]] = None,
606
613
  **kwargs,
@@ -638,7 +645,7 @@ class WorkerActor(xo.StatelessActor):
638
645
  if isinstance(n_gpu, str) and n_gpu != "auto":
639
646
  raise ValueError("Currently `n_gpu` only supports `auto`.")
640
647
 
641
- if peft_model_path is not None:
648
+ if peft_model_config is not None:
642
649
  if model_type in ("embedding", "rerank"):
643
650
  raise ValueError(
644
651
  f"PEFT adaptors cannot be applied to embedding or rerank models."
@@ -669,9 +676,7 @@ class WorkerActor(xo.StatelessActor):
669
676
  model_format,
670
677
  model_size_in_billions,
671
678
  quantization,
672
- peft_model_path,
673
- image_lora_load_kwargs,
674
- image_lora_fuse_kwargs,
679
+ peft_model_config,
675
680
  is_local_deployment,
676
681
  **kwargs,
677
682
  )
@@ -640,10 +640,11 @@ def list_model_registrations(
640
640
  help='The number of GPUs used by the model, default is "auto".',
641
641
  )
642
642
  @click.option(
643
- "--peft-model-path",
644
- default=None,
645
- type=str,
646
- help="PEFT model path.",
643
+ "--lora-modules",
644
+ "-lm",
645
+ multiple=True,
646
+ type=(str, str),
647
+ help="LoRA module configurations in the format name=path. Multiple modules can be specified.",
647
648
  )
648
649
  @click.option(
649
650
  "--image-lora-load-kwargs",
@@ -696,7 +697,7 @@ def model_launch(
696
697
  quantization: str,
697
698
  replica: int,
698
699
  n_gpu: str,
699
- peft_model_path: Optional[str],
700
+ lora_modules: Optional[Tuple],
700
701
  image_lora_load_kwargs: Optional[Tuple],
701
702
  image_lora_fuse_kwargs: Optional[Tuple],
702
703
  worker_ip: Optional[str],
@@ -729,6 +730,18 @@ def model_launch(
729
730
  else None
730
731
  )
731
732
 
733
+ lora_list = (
734
+ [{"lora_name": k, "local_path": v} for k, v in dict(lora_modules).items()]
735
+ if lora_modules
736
+ else []
737
+ )
738
+
739
+ peft_model_config = {
740
+ "image_lora_load_kwargs": image_lora_load_params,
741
+ "image_lora_fuse_kwargs": image_lora_fuse_params,
742
+ "lora_list": lora_list,
743
+ }
744
+
732
745
  _gpu_idx: Optional[List[int]] = (
733
746
  None if gpu_idx is None else [int(idx) for idx in gpu_idx.split(",")]
734
747
  )
@@ -736,7 +749,9 @@ def model_launch(
736
749
  endpoint = get_endpoint(endpoint)
737
750
  model_size: Optional[Union[str, int]] = (
738
751
  size_in_billions
739
- if size_in_billions is None or "_" in size_in_billions
752
+ if size_in_billions is None
753
+ or "_" in size_in_billions
754
+ or "." in size_in_billions
740
755
  else int(size_in_billions)
741
756
  )
742
757
  client = RESTfulClient(base_url=endpoint, api_key=api_key)
@@ -752,9 +767,7 @@ def model_launch(
752
767
  quantization=quantization,
753
768
  replica=replica,
754
769
  n_gpu=_n_gpu,
755
- peft_model_path=peft_model_path,
756
- image_lora_load_kwargs=image_lora_load_params,
757
- image_lora_fuse_kwargs=image_lora_fuse_params,
770
+ peft_model_config=peft_model_config,
758
771
  worker_ip=worker_ip,
759
772
  gpu_idx=_gpu_idx,
760
773
  trust_remote_code=trust_remote_code,
@@ -16,12 +16,51 @@ import codecs
16
16
  import json
17
17
  import os
18
18
 
19
- from .core import AudioModelFamilyV1, generate_audio_description, get_cache_status
19
+ from .core import (
20
+ AUDIO_MODEL_DESCRIPTIONS,
21
+ MODEL_NAME_TO_REVISION,
22
+ AudioModelFamilyV1,
23
+ generate_audio_description,
24
+ get_audio_model_descriptions,
25
+ get_cache_status,
26
+ )
27
+ from .custom import (
28
+ CustomAudioModelFamilyV1,
29
+ get_user_defined_audios,
30
+ register_audio,
31
+ unregister_audio,
32
+ )
20
33
 
21
34
  _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
22
35
  BUILTIN_AUDIO_MODELS = dict(
23
36
  (spec["model_name"], AudioModelFamilyV1(**spec))
24
37
  for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
25
38
  )
39
+ for model_name, model_spec in BUILTIN_AUDIO_MODELS.items():
40
+ MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
41
+
42
+ # register model description after recording model revision
43
+ for model_spec_info in [BUILTIN_AUDIO_MODELS]:
44
+ for model_name, model_spec in model_spec_info.items():
45
+ if model_spec.model_name not in AUDIO_MODEL_DESCRIPTIONS:
46
+ AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(model_spec))
47
+
48
+ from ...constants import XINFERENCE_MODEL_DIR
49
+
50
+ # if persist=True, load them when init
51
+ user_defined_audio_dir = os.path.join(XINFERENCE_MODEL_DIR, "audio")
52
+ if os.path.isdir(user_defined_audio_dir):
53
+ for f in os.listdir(user_defined_audio_dir):
54
+ with codecs.open(
55
+ os.path.join(user_defined_audio_dir, f), encoding="utf-8"
56
+ ) as fd:
57
+ user_defined_audio_family = CustomAudioModelFamilyV1.parse_obj(
58
+ json.load(fd)
59
+ )
60
+ register_audio(user_defined_audio_family, persist=False)
61
+
62
+ # register model description
63
+ for ud_audio in get_user_defined_audios():
64
+ AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(ud_audio))
26
65
 
27
66
  del _model_spec_json
@@ -16,9 +16,8 @@ import os
16
16
  from collections import defaultdict
17
17
  from typing import Dict, List, Optional, Tuple
18
18
 
19
- from ..._compat import BaseModel
20
19
  from ...constants import XINFERENCE_CACHE_DIR
21
- from ..core import ModelDescription
20
+ from ..core import CacheableModelSpec, ModelDescription
22
21
  from ..utils import valid_model_revision
23
22
  from .whisper import WhisperModel
24
23
 
@@ -26,8 +25,19 @@ MAX_ATTEMPTS = 3
26
25
 
27
26
  logger = logging.getLogger(__name__)
28
27
 
28
+ # Used for check whether the model is cached.
29
+ # Init when registering all the builtin models.
30
+ MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
31
+ AUDIO_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
29
32
 
30
- class AudioModelFamilyV1(BaseModel):
33
+
34
+ def get_audio_model_descriptions():
35
+ import copy
36
+
37
+ return copy.deepcopy(AUDIO_MODEL_DESCRIPTIONS)
38
+
39
+
40
+ class AudioModelFamilyV1(CacheableModelSpec):
31
41
  model_family: str
32
42
  model_name: str
33
43
  model_id: str
@@ -77,63 +87,33 @@ def generate_audio_description(
77
87
  image_model: AudioModelFamilyV1,
78
88
  ) -> Dict[str, List[Dict]]:
79
89
  res = defaultdict(list)
80
- res[image_model.model_name].extend(
81
- AudioModelDescription(None, None, image_model).to_dict()
90
+ res[image_model.model_name].append(
91
+ AudioModelDescription(None, None, image_model).to_version_info()
82
92
  )
83
93
  return res
84
94
 
85
95
 
86
- def match_model(model_name: str) -> AudioModelFamilyV1:
96
+ def match_audio(model_name: str) -> AudioModelFamilyV1:
87
97
  from . import BUILTIN_AUDIO_MODELS
98
+ from .custom import get_user_defined_audios
99
+
100
+ for model_spec in get_user_defined_audios():
101
+ if model_spec.model_name == model_name:
102
+ return model_spec
88
103
 
89
104
  if model_name in BUILTIN_AUDIO_MODELS:
90
105
  return BUILTIN_AUDIO_MODELS[model_name]
91
106
  else:
92
107
  raise ValueError(
93
- f"Image model {model_name} not found, available"
108
+ f"Audio model {model_name} not found, available"
94
109
  f"model list: {BUILTIN_AUDIO_MODELS.keys()}"
95
110
  )
96
111
 
97
112
 
98
113
  def cache(model_spec: AudioModelFamilyV1):
99
- # TODO: cache from uri
100
- import huggingface_hub
101
-
102
- cache_dir = get_cache_dir(model_spec)
103
- if not os.path.exists(cache_dir):
104
- os.makedirs(cache_dir, exist_ok=True)
105
-
106
- meta_path = os.path.join(cache_dir, "__valid_download")
107
- if valid_model_revision(meta_path, model_spec.model_revision):
108
- return cache_dir
109
-
110
- for current_attempt in range(1, MAX_ATTEMPTS + 1):
111
- try:
112
- huggingface_hub.snapshot_download(
113
- model_spec.model_id,
114
- revision=model_spec.model_revision,
115
- local_dir=cache_dir,
116
- local_dir_use_symlinks=True,
117
- resume_download=True,
118
- )
119
- break
120
- except huggingface_hub.utils.LocalEntryNotFoundError:
121
- remaining_attempts = MAX_ATTEMPTS - current_attempt
122
- logger.warning(
123
- f"Attempt {current_attempt} failed. Remaining attempts: {remaining_attempts}"
124
- )
125
- else:
126
- raise RuntimeError(
127
- f"Failed to download model '{model_spec.model_name}' after {MAX_ATTEMPTS} attempts"
128
- )
129
-
130
- with open(meta_path, "w") as f:
131
- import json
132
-
133
- desc = AudioModelDescription(None, None, model_spec)
134
- json.dump(desc.to_dict(), f)
114
+ from ..utils import cache
135
115
 
136
- return cache_dir
116
+ return cache(model_spec, AudioModelDescription)
137
117
 
138
118
 
139
119
  def get_cache_dir(model_spec: AudioModelFamilyV1):
@@ -151,7 +131,7 @@ def get_cache_status(
151
131
  def create_audio_model_instance(
152
132
  subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
153
133
  ) -> Tuple[WhisperModel, AudioModelDescription]:
154
- model_spec = match_model(model_name)
134
+ model_spec = match_audio(model_name)
155
135
  model_path = cache(model_spec)
156
136
  model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
157
137
  model_description = AudioModelDescription(