xinference 0.13.0__py3-none-any.whl → 0.13.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (70) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +123 -3
  3. xinference/client/restful/restful_client.py +131 -2
  4. xinference/core/model.py +93 -24
  5. xinference/core/supervisor.py +132 -15
  6. xinference/core/worker.py +165 -8
  7. xinference/deploy/cmdline.py +5 -0
  8. xinference/model/audio/chattts.py +46 -14
  9. xinference/model/audio/core.py +23 -15
  10. xinference/model/core.py +12 -3
  11. xinference/model/embedding/core.py +25 -16
  12. xinference/model/flexible/__init__.py +40 -0
  13. xinference/model/flexible/core.py +228 -0
  14. xinference/model/flexible/launchers/__init__.py +15 -0
  15. xinference/model/flexible/launchers/transformers_launcher.py +63 -0
  16. xinference/model/flexible/utils.py +33 -0
  17. xinference/model/image/core.py +21 -14
  18. xinference/model/image/custom.py +1 -1
  19. xinference/model/image/model_spec.json +14 -0
  20. xinference/model/image/stable_diffusion/core.py +43 -6
  21. xinference/model/llm/__init__.py +0 -2
  22. xinference/model/llm/core.py +3 -2
  23. xinference/model/llm/ggml/llamacpp.py +1 -10
  24. xinference/model/llm/llm_family.json +292 -36
  25. xinference/model/llm/llm_family.py +97 -52
  26. xinference/model/llm/llm_family_modelscope.json +220 -27
  27. xinference/model/llm/pytorch/core.py +0 -80
  28. xinference/model/llm/sglang/core.py +7 -2
  29. xinference/model/llm/utils.py +4 -2
  30. xinference/model/llm/vllm/core.py +3 -0
  31. xinference/model/rerank/core.py +24 -25
  32. xinference/types.py +0 -1
  33. xinference/web/ui/build/asset-manifest.json +3 -3
  34. xinference/web/ui/build/index.html +1 -1
  35. xinference/web/ui/build/static/js/{main.0fb6f3ab.js → main.95c1d652.js} +3 -3
  36. xinference/web/ui/build/static/js/main.95c1d652.js.map +1 -0
  37. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +1 -0
  38. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +1 -0
  42. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +1 -0
  43. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/METADATA +9 -11
  44. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/RECORD +49 -58
  45. xinference/model/llm/ggml/chatglm.py +0 -457
  46. xinference/thirdparty/ChatTTS/__init__.py +0 -1
  47. xinference/thirdparty/ChatTTS/core.py +0 -200
  48. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  49. xinference/thirdparty/ChatTTS/experimental/llm.py +0 -40
  50. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  51. xinference/thirdparty/ChatTTS/infer/api.py +0 -125
  52. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  53. xinference/thirdparty/ChatTTS/model/dvae.py +0 -155
  54. xinference/thirdparty/ChatTTS/model/gpt.py +0 -265
  55. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  56. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +0 -23
  57. xinference/thirdparty/ChatTTS/utils/infer_utils.py +0 -141
  58. xinference/thirdparty/ChatTTS/utils/io_utils.py +0 -14
  59. xinference/web/ui/build/static/js/main.0fb6f3ab.js.map +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/0f6b391abec76271137faad13a3793fe7acc1024e8cd2269c147b653ecd3a73b.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/30a0c79d8025d6441eb75b2df5bc2750a14f30119c869ef02570d294dff65c2f.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/40486e655c3c5801f087e2cf206c0b5511aaa0dfdba78046b7181bf9c17e54c5.json +0 -1
  63. xinference/web/ui/node_modules/.cache/babel-loader/b5507cd57f16a3a230aa0128e39fe103e928de139ea29e2679e4c64dcbba3b3a.json +0 -1
  64. xinference/web/ui/node_modules/.cache/babel-loader/d779b915f83f9c7b5a72515b6932fdd114f1822cef90ae01cc0d12bca59abc2d.json +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/d87824cb266194447a9c0c69ebab2d507bfc3e3148976173760d18c035e9dd26.json +0 -1
  66. /xinference/web/ui/build/static/js/{main.0fb6f3ab.js.LICENSE.txt → main.95c1d652.js.LICENSE.txt} +0 -0
  67. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/LICENSE +0 -0
  68. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/WHEEL +0 -0
  69. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/entry_points.txt +0 -0
  70. {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,17 @@ import time
20
20
  import typing
21
21
  from dataclasses import dataclass
22
22
  from logging import getLogger
23
- from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
23
+ from typing import (
24
+ TYPE_CHECKING,
25
+ Any,
26
+ Dict,
27
+ Iterator,
28
+ List,
29
+ Literal,
30
+ Optional,
31
+ Tuple,
32
+ Union,
33
+ )
24
34
 
25
35
  import xoscar as xo
26
36
 
@@ -50,6 +60,7 @@ from .utils import (
50
60
  if TYPE_CHECKING:
51
61
  from ..model.audio import AudioModelFamilyV1
52
62
  from ..model.embedding import EmbeddingModelSpec
63
+ from ..model.flexible import FlexibleModelSpec
53
64
  from ..model.image import ImageModelFamilyV1
54
65
  from ..model.llm import LLMFamilyV1
55
66
  from ..model.rerank import RerankModelSpec
@@ -153,6 +164,13 @@ class SupervisorActor(xo.StatelessActor):
153
164
  register_embedding,
154
165
  unregister_embedding,
155
166
  )
167
+ from ..model.flexible import (
168
+ FlexibleModelSpec,
169
+ generate_flexible_model_description,
170
+ get_flexible_model_descriptions,
171
+ register_flexible_model,
172
+ unregister_flexible_model,
173
+ )
156
174
  from ..model.image import (
157
175
  CustomImageModelFamilyV1,
158
176
  generate_image_description,
@@ -206,6 +224,12 @@ class SupervisorActor(xo.StatelessActor):
206
224
  unregister_audio,
207
225
  generate_audio_description,
208
226
  ),
227
+ "flexible": (
228
+ FlexibleModelSpec,
229
+ register_flexible_model,
230
+ unregister_flexible_model,
231
+ generate_flexible_model_description,
232
+ ),
209
233
  }
210
234
 
211
235
  # record model version
@@ -215,6 +239,7 @@ class SupervisorActor(xo.StatelessActor):
215
239
  model_version_infos.update(get_rerank_model_descriptions())
216
240
  model_version_infos.update(get_image_model_descriptions())
217
241
  model_version_infos.update(get_audio_model_descriptions())
242
+ model_version_infos.update(get_flexible_model_descriptions())
218
243
  await self._cache_tracker_ref.record_model_version(
219
244
  model_version_infos, self.address
220
245
  )
@@ -459,6 +484,27 @@ class SupervisorActor(xo.StatelessActor):
459
484
  res["model_instance_count"] = instance_cnt
460
485
  return res
461
486
 
487
+ async def _to_flexible_model_reg(
488
+ self, model_spec: "FlexibleModelSpec", is_builtin: bool
489
+ ) -> Dict[str, Any]:
490
+ instance_cnt = await self.get_instance_count(model_spec.model_name)
491
+ version_cnt = await self.get_model_version_count(model_spec.model_name)
492
+
493
+ if self.is_local_deployment():
494
+ res = {
495
+ **model_spec.dict(),
496
+ "cache_status": True,
497
+ "is_builtin": is_builtin,
498
+ }
499
+ else:
500
+ res = {
501
+ **model_spec.dict(),
502
+ "is_builtin": is_builtin,
503
+ }
504
+ res["model_version_count"] = version_cnt
505
+ res["model_instance_count"] = instance_cnt
506
+ return res
507
+
462
508
  @log_async(logger=logger)
463
509
  async def list_model_registrations(
464
510
  self, model_type: str, detailed: bool = False
@@ -467,10 +513,15 @@ class SupervisorActor(xo.StatelessActor):
467
513
  assert isinstance(item["model_name"], str)
468
514
  return item.get("model_name").lower()
469
515
 
516
+ ret = []
517
+ if not self.is_local_deployment():
518
+ workers = list(self._worker_address_to_worker.values())
519
+ for worker in workers:
520
+ ret.extend(await worker.list_model_registrations(model_type, detailed))
521
+
470
522
  if model_type == "LLM":
471
523
  from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families
472
524
 
473
- ret = []
474
525
  for family in BUILTIN_LLM_FAMILIES:
475
526
  if detailed:
476
527
  ret.append(await self._to_llm_reg(family, True))
@@ -489,7 +540,6 @@ class SupervisorActor(xo.StatelessActor):
489
540
  from ..model.embedding import BUILTIN_EMBEDDING_MODELS
490
541
  from ..model.embedding.custom import get_user_defined_embeddings
491
542
 
492
- ret = []
493
543
  for model_name, family in BUILTIN_EMBEDDING_MODELS.items():
494
544
  if detailed:
495
545
  ret.append(
@@ -514,7 +564,6 @@ class SupervisorActor(xo.StatelessActor):
514
564
  from ..model.image import BUILTIN_IMAGE_MODELS
515
565
  from ..model.image.custom import get_user_defined_images
516
566
 
517
- ret = []
518
567
  for model_name, family in BUILTIN_IMAGE_MODELS.items():
519
568
  if detailed:
520
569
  ret.append(await self._to_image_model_reg(family, is_builtin=True))
@@ -537,7 +586,6 @@ class SupervisorActor(xo.StatelessActor):
537
586
  from ..model.audio import BUILTIN_AUDIO_MODELS
538
587
  from ..model.audio.custom import get_user_defined_audios
539
588
 
540
- ret = []
541
589
  for model_name, family in BUILTIN_AUDIO_MODELS.items():
542
590
  if detailed:
543
591
  ret.append(await self._to_audio_model_reg(family, is_builtin=True))
@@ -560,7 +608,6 @@ class SupervisorActor(xo.StatelessActor):
560
608
  from ..model.rerank import BUILTIN_RERANK_MODELS
561
609
  from ..model.rerank.custom import get_user_defined_reranks
562
610
 
563
- ret = []
564
611
  for model_name, family in BUILTIN_RERANK_MODELS.items():
565
612
  if detailed:
566
613
  ret.append(await self._to_rerank_model_reg(family, is_builtin=True))
@@ -577,13 +624,38 @@ class SupervisorActor(xo.StatelessActor):
577
624
  {"model_name": model_spec.model_name, "is_builtin": False}
578
625
  )
579
626
 
627
+ ret.sort(key=sort_helper)
628
+ return ret
629
+ elif model_type == "flexible":
630
+ from ..model.flexible import get_flexible_models
631
+
632
+ ret = []
633
+
634
+ for model_spec in get_flexible_models():
635
+ if detailed:
636
+ ret.append(
637
+ await self._to_flexible_model_reg(model_spec, is_builtin=False)
638
+ )
639
+ else:
640
+ ret.append(
641
+ {"model_name": model_spec.model_name, "is_builtin": False}
642
+ )
643
+
580
644
  ret.sort(key=sort_helper)
581
645
  return ret
582
646
  else:
583
647
  raise ValueError(f"Unsupported model type: {model_type}")
584
648
 
585
649
  @log_sync(logger=logger)
586
- def get_model_registration(self, model_type: str, model_name: str) -> Any:
650
+ async def get_model_registration(self, model_type: str, model_name: str) -> Any:
651
+ # search in worker first
652
+ if not self.is_local_deployment():
653
+ workers = list(self._worker_address_to_worker.values())
654
+ for worker in workers:
655
+ f = await worker.get_model_registration(model_type, model_name)
656
+ if f is not None:
657
+ return f
658
+
587
659
  if model_type == "LLM":
588
660
  from ..model.llm import BUILTIN_LLM_FAMILIES, get_user_defined_llm_families
589
661
 
@@ -626,6 +698,13 @@ class SupervisorActor(xo.StatelessActor):
626
698
  if f.model_name == model_name:
627
699
  return f
628
700
  raise ValueError(f"Model {model_name} not found")
701
+ elif model_type == "flexible":
702
+ from ..model.flexible import get_flexible_models
703
+
704
+ for f in get_flexible_models():
705
+ if f.model_name == model_name:
706
+ return f
707
+ raise ValueError(f"Model {model_name} not found")
629
708
  else:
630
709
  raise ValueError(f"Unsupported model type: {model_type}")
631
710
 
@@ -635,6 +714,13 @@ class SupervisorActor(xo.StatelessActor):
635
714
 
636
715
  from ..model.llm.llm_family import LLM_ENGINES
637
716
 
717
+ # search in worker first
718
+ workers = list(self._worker_address_to_worker.values())
719
+ for worker in workers:
720
+ res = await worker.query_engines_by_model_name(model_name)
721
+ if res is not None:
722
+ return res
723
+
638
724
  if model_name not in LLM_ENGINES:
639
725
  raise ValueError(f"Model {model_name} not found")
640
726
 
@@ -648,7 +734,13 @@ class SupervisorActor(xo.StatelessActor):
648
734
  return engine_params
649
735
 
650
736
  @log_async(logger=logger)
651
- async def register_model(self, model_type: str, model: str, persist: bool):
737
+ async def register_model(
738
+ self,
739
+ model_type: str,
740
+ model: str,
741
+ persist: bool,
742
+ worker_ip: Optional[str] = None,
743
+ ):
652
744
  if model_type in self._custom_register_type_to_cls:
653
745
  (
654
746
  model_spec_cls,
@@ -657,10 +749,21 @@ class SupervisorActor(xo.StatelessActor):
657
749
  generate_fn,
658
750
  ) = self._custom_register_type_to_cls[model_type]
659
751
 
660
- if not self.is_local_deployment():
661
- workers = list(self._worker_address_to_worker.values())
662
- for worker in workers:
663
- await worker.register_model(model_type, model, persist)
752
+ target_ip_worker_ref = (
753
+ self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
754
+ )
755
+ if (
756
+ worker_ip is not None
757
+ and not self.is_local_deployment()
758
+ and target_ip_worker_ref is None
759
+ ):
760
+ raise ValueError(
761
+ f"Worker ip address {worker_ip} is not in the cluster."
762
+ )
763
+
764
+ if target_ip_worker_ref:
765
+ await target_ip_worker_ref.register_model(model_type, model, persist)
766
+ return
664
767
 
665
768
  model_spec = model_spec_cls.parse_raw(model)
666
769
  try:
@@ -668,6 +771,8 @@ class SupervisorActor(xo.StatelessActor):
668
771
  await self._cache_tracker_ref.record_model_version(
669
772
  generate_fn(model_spec), self.address
670
773
  )
774
+ except ValueError as e:
775
+ raise e
671
776
  except Exception as e:
672
777
  unregister_fn(model_spec.model_name, raise_error=False)
673
778
  raise e
@@ -678,13 +783,14 @@ class SupervisorActor(xo.StatelessActor):
678
783
  async def unregister_model(self, model_type: str, model_name: str):
679
784
  if model_type in self._custom_register_type_to_cls:
680
785
  _, _, unregister_fn, _ = self._custom_register_type_to_cls[model_type]
681
- unregister_fn(model_name)
682
- await self._cache_tracker_ref.unregister_model_version(model_name)
786
+ unregister_fn(model_name, False)
683
787
 
684
788
  if not self.is_local_deployment():
685
789
  workers = list(self._worker_address_to_worker.values())
686
790
  for worker in workers:
687
- await worker.unregister_model(model_name)
791
+ await worker.unregister_model(model_type, model_name)
792
+
793
+ await self._cache_tracker_ref.unregister_model_version(model_name)
688
794
  else:
689
795
  raise ValueError(f"Unsupported model type: {model_type}")
690
796
 
@@ -752,8 +858,17 @@ class SupervisorActor(xo.StatelessActor):
752
858
  peft_model_config: Optional[PeftModelConfig] = None,
753
859
  worker_ip: Optional[str] = None,
754
860
  gpu_idx: Optional[Union[int, List[int]]] = None,
861
+ download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
755
862
  **kwargs,
756
863
  ) -> str:
864
+ # search in worker first
865
+ if not self.is_local_deployment():
866
+ workers = list(self._worker_address_to_worker.values())
867
+ for worker in workers:
868
+ res = await worker.get_model_registration(model_type, model_name)
869
+ if res is not None:
870
+ worker_ip = worker.address.split(":")[0]
871
+
757
872
  target_ip_worker_ref = (
758
873
  self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
759
874
  )
@@ -806,6 +921,7 @@ class SupervisorActor(xo.StatelessActor):
806
921
  )
807
922
  replica_gpu_idx = assign_replica_gpu(_replica_model_uid, gpu_idx)
808
923
  nonlocal model_type
924
+
809
925
  worker_ref = (
810
926
  target_ip_worker_ref
811
927
  if target_ip_worker_ref is not None
@@ -825,6 +941,7 @@ class SupervisorActor(xo.StatelessActor):
825
941
  request_limits=request_limits,
826
942
  peft_model_config=peft_model_config,
827
943
  gpu_idx=replica_gpu_idx,
944
+ download_hub=download_hub,
828
945
  **kwargs,
829
946
  )
830
947
  self._replica_model_uid_to_worker[_replica_model_uid] = worker_ref
xinference/core/worker.py CHANGED
@@ -22,7 +22,7 @@ import threading
22
22
  import time
23
23
  from collections import defaultdict
24
24
  from logging import getLogger
25
- from typing import Any, Dict, List, Optional, Set, Tuple, Union
25
+ from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
26
26
 
27
27
  import xoscar as xo
28
28
  from async_timeout import timeout
@@ -212,48 +212,81 @@ class WorkerActor(xo.StatelessActor):
212
212
 
213
213
  from ..model.audio import (
214
214
  CustomAudioModelFamilyV1,
215
+ generate_audio_description,
215
216
  get_audio_model_descriptions,
216
217
  register_audio,
217
218
  unregister_audio,
218
219
  )
219
220
  from ..model.embedding import (
220
221
  CustomEmbeddingModelSpec,
222
+ generate_embedding_description,
221
223
  get_embedding_model_descriptions,
222
224
  register_embedding,
223
225
  unregister_embedding,
224
226
  )
227
+ from ..model.flexible import (
228
+ FlexibleModelSpec,
229
+ get_flexible_model_descriptions,
230
+ register_flexible_model,
231
+ unregister_flexible_model,
232
+ )
225
233
  from ..model.image import (
226
234
  CustomImageModelFamilyV1,
235
+ generate_image_description,
227
236
  get_image_model_descriptions,
228
237
  register_image,
229
238
  unregister_image,
230
239
  )
231
240
  from ..model.llm import (
232
241
  CustomLLMFamilyV1,
242
+ generate_llm_description,
233
243
  get_llm_model_descriptions,
234
244
  register_llm,
235
245
  unregister_llm,
236
246
  )
237
247
  from ..model.rerank import (
238
248
  CustomRerankModelSpec,
249
+ generate_rerank_description,
239
250
  get_rerank_model_descriptions,
240
251
  register_rerank,
241
252
  unregister_rerank,
242
253
  )
243
254
 
244
255
  self._custom_register_type_to_cls: Dict[str, Tuple] = { # type: ignore
245
- "LLM": (CustomLLMFamilyV1, register_llm, unregister_llm),
256
+ "LLM": (
257
+ CustomLLMFamilyV1,
258
+ register_llm,
259
+ unregister_llm,
260
+ generate_llm_description,
261
+ ),
246
262
  "embedding": (
247
263
  CustomEmbeddingModelSpec,
248
264
  register_embedding,
249
265
  unregister_embedding,
266
+ generate_embedding_description,
267
+ ),
268
+ "rerank": (
269
+ CustomRerankModelSpec,
270
+ register_rerank,
271
+ unregister_rerank,
272
+ generate_rerank_description,
250
273
  ),
251
- "rerank": (CustomRerankModelSpec, register_rerank, unregister_rerank),
252
- "audio": (CustomAudioModelFamilyV1, register_audio, unregister_audio),
253
274
  "image": (
254
275
  CustomImageModelFamilyV1,
255
276
  register_image,
256
277
  unregister_image,
278
+ generate_image_description,
279
+ ),
280
+ "audio": (
281
+ CustomAudioModelFamilyV1,
282
+ register_audio,
283
+ unregister_audio,
284
+ generate_audio_description,
285
+ ),
286
+ "flexible": (
287
+ FlexibleModelSpec,
288
+ register_flexible_model,
289
+ unregister_flexible_model,
257
290
  ),
258
291
  }
259
292
 
@@ -264,6 +297,7 @@ class WorkerActor(xo.StatelessActor):
264
297
  model_version_infos.update(get_rerank_model_descriptions())
265
298
  model_version_infos.update(get_image_model_descriptions())
266
299
  model_version_infos.update(get_audio_model_descriptions())
300
+ model_version_infos.update(get_flexible_model_descriptions())
267
301
  await self._cache_tracker_ref.record_model_version(
268
302
  model_version_infos, self.address
269
303
  )
@@ -514,17 +548,23 @@ class WorkerActor(xo.StatelessActor):
514
548
  raise ValueError(f"{model_name} model can't run on Darwin system.")
515
549
 
516
550
  @log_sync(logger=logger)
517
- def register_model(self, model_type: str, model: str, persist: bool):
551
+ async def register_model(self, model_type: str, model: str, persist: bool):
518
552
  # TODO: centralized model registrations
519
553
  if model_type in self._custom_register_type_to_cls:
520
554
  (
521
555
  model_spec_cls,
522
556
  register_fn,
523
557
  unregister_fn,
558
+ generate_fn,
524
559
  ) = self._custom_register_type_to_cls[model_type]
525
560
  model_spec = model_spec_cls.parse_raw(model)
526
561
  try:
527
562
  register_fn(model_spec, persist)
563
+ await self._cache_tracker_ref.record_model_version(
564
+ generate_fn(model_spec), self.address
565
+ )
566
+ except ValueError as e:
567
+ raise e
528
568
  except Exception as e:
529
569
  unregister_fn(model_spec.model_name, raise_error=False)
530
570
  raise e
@@ -532,14 +572,127 @@ class WorkerActor(xo.StatelessActor):
532
572
  raise ValueError(f"Unsupported model type: {model_type}")
533
573
 
534
574
  @log_sync(logger=logger)
535
- def unregister_model(self, model_type: str, model_name: str):
575
+ async def unregister_model(self, model_type: str, model_name: str):
536
576
  # TODO: centralized model registrations
537
577
  if model_type in self._custom_register_type_to_cls:
538
- _, _, unregister_fn = self._custom_register_type_to_cls[model_type]
539
- unregister_fn(model_name)
578
+ _, _, unregister_fn, _ = self._custom_register_type_to_cls[model_type]
579
+ unregister_fn(model_name, False)
540
580
  else:
541
581
  raise ValueError(f"Unsupported model type: {model_type}")
542
582
 
583
+ @log_async(logger=logger)
584
+ async def list_model_registrations(
585
+ self, model_type: str, detailed: bool = False
586
+ ) -> List[Dict[str, Any]]:
587
+ def sort_helper(item):
588
+ assert isinstance(item["model_name"], str)
589
+ return item.get("model_name").lower()
590
+
591
+ if model_type == "LLM":
592
+ from ..model.llm import get_user_defined_llm_families
593
+
594
+ ret = []
595
+
596
+ for family in get_user_defined_llm_families():
597
+ ret.append({"model_name": family.model_name, "is_builtin": False})
598
+
599
+ ret.sort(key=sort_helper)
600
+ return ret
601
+ elif model_type == "embedding":
602
+ from ..model.embedding.custom import get_user_defined_embeddings
603
+
604
+ ret = []
605
+
606
+ for model_spec in get_user_defined_embeddings():
607
+ ret.append({"model_name": model_spec.model_name, "is_builtin": False})
608
+
609
+ ret.sort(key=sort_helper)
610
+ return ret
611
+ elif model_type == "image":
612
+ from ..model.image.custom import get_user_defined_images
613
+
614
+ ret = []
615
+
616
+ for model_spec in get_user_defined_images():
617
+ ret.append({"model_name": model_spec.model_name, "is_builtin": False})
618
+
619
+ ret.sort(key=sort_helper)
620
+ return ret
621
+ elif model_type == "audio":
622
+ from ..model.audio.custom import get_user_defined_audios
623
+
624
+ ret = []
625
+
626
+ for model_spec in get_user_defined_audios():
627
+ ret.append({"model_name": model_spec.model_name, "is_builtin": False})
628
+
629
+ ret.sort(key=sort_helper)
630
+ return ret
631
+ elif model_type == "rerank":
632
+ from ..model.rerank.custom import get_user_defined_reranks
633
+
634
+ ret = []
635
+
636
+ for model_spec in get_user_defined_reranks():
637
+ ret.append({"model_name": model_spec.model_name, "is_builtin": False})
638
+
639
+ ret.sort(key=sort_helper)
640
+ return ret
641
+ else:
642
+ raise ValueError(f"Unsupported model type: {model_type}")
643
+
644
+ @log_sync(logger=logger)
645
+ async def get_model_registration(self, model_type: str, model_name: str) -> Any:
646
+ if model_type == "LLM":
647
+ from ..model.llm import get_user_defined_llm_families
648
+
649
+ for f in get_user_defined_llm_families():
650
+ if f.model_name == model_name:
651
+ return f
652
+ elif model_type == "embedding":
653
+ from ..model.embedding.custom import get_user_defined_embeddings
654
+
655
+ for f in get_user_defined_embeddings():
656
+ if f.model_name == model_name:
657
+ return f
658
+ elif model_type == "image":
659
+ from ..model.image.custom import get_user_defined_images
660
+
661
+ for f in get_user_defined_images():
662
+ if f.model_name == model_name:
663
+ return f
664
+ elif model_type == "audio":
665
+ from ..model.audio.custom import get_user_defined_audios
666
+
667
+ for f in get_user_defined_audios():
668
+ if f.model_name == model_name:
669
+ return f
670
+ elif model_type == "rerank":
671
+ from ..model.rerank.custom import get_user_defined_reranks
672
+
673
+ for f in get_user_defined_reranks():
674
+ if f.model_name == model_name:
675
+ return f
676
+ return None
677
+
678
+ @log_async(logger=logger)
679
+ async def query_engines_by_model_name(self, model_name: str):
680
+ from copy import deepcopy
681
+
682
+ from ..model.llm.llm_family import LLM_ENGINES
683
+
684
+ if model_name not in LLM_ENGINES:
685
+ return None
686
+
687
+ # filter llm_class
688
+ engine_params = deepcopy(LLM_ENGINES[model_name])
689
+ for engine in engine_params:
690
+ params = engine_params[engine]
691
+ for param in params:
692
+ del param["llm_class"]
693
+
694
+ return engine_params
695
+
543
696
  async def _get_model_ability(self, model: Any, model_type: str) -> List[str]:
544
697
  from ..model.llm.core import LLM
545
698
 
@@ -551,6 +704,8 @@ class WorkerActor(xo.StatelessActor):
551
704
  return ["text_to_image"]
552
705
  elif model_type == "audio":
553
706
  return ["audio_to_text"]
707
+ elif model_type == "flexible":
708
+ return ["flexible"]
554
709
  else:
555
710
  assert model_type == "LLM"
556
711
  assert isinstance(model, LLM)
@@ -587,6 +742,7 @@ class WorkerActor(xo.StatelessActor):
587
742
  peft_model_config: Optional[PeftModelConfig] = None,
588
743
  request_limits: Optional[int] = None,
589
744
  gpu_idx: Optional[Union[int, List[int]]] = None,
745
+ download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
590
746
  **kwargs,
591
747
  ):
592
748
  # !!! Note that The following code must be placed at the very beginning of this function,
@@ -669,6 +825,7 @@ class WorkerActor(xo.StatelessActor):
669
825
  model_size_in_billions,
670
826
  quantization,
671
827
  peft_model_config,
828
+ download_hub,
672
829
  **kwargs,
673
830
  )
674
831
  await self.update_cache_status(model_name, model_description)
@@ -370,6 +370,9 @@ def worker(
370
370
  help="Type of model to register (default is 'LLM').",
371
371
  )
372
372
  @click.option("--file", "-f", type=str, help="Path to the model configuration file.")
373
+ @click.option(
374
+ "--worker-ip", "-w", type=str, help="Specify the ip address of the worker."
375
+ )
373
376
  @click.option(
374
377
  "--persist",
375
378
  "-p",
@@ -387,6 +390,7 @@ def register_model(
387
390
  endpoint: Optional[str],
388
391
  model_type: str,
389
392
  file: str,
393
+ worker_ip: str,
390
394
  persist: bool,
391
395
  api_key: Optional[str],
392
396
  ):
@@ -400,6 +404,7 @@ def register_model(
400
404
  client.register_model(
401
405
  model_type=model_type,
402
406
  model=model,
407
+ worker_ip=worker_ip,
403
408
  persist=persist,
404
409
  )
405
410
 
@@ -38,21 +38,24 @@ class ChatTTSModel:
38
38
  self._kwargs = kwargs
39
39
 
40
40
  def load(self):
41
+ import ChatTTS
41
42
  import torch
42
43
 
43
- from xinference.thirdparty import ChatTTS
44
-
45
44
  torch._dynamo.config.cache_size_limit = 64
46
45
  torch._dynamo.config.suppress_errors = True
47
46
  torch.set_float32_matmul_precision("high")
48
47
  self._model = ChatTTS.Chat()
49
- self._model.load_models(
50
- source="local", local_path=self._model_path, compile=True
51
- )
48
+ self._model.load(source="custom", custom_path=self._model_path, compile=True)
52
49
 
53
50
  def speech(
54
- self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
51
+ self,
52
+ input: str,
53
+ voice: str,
54
+ response_format: str = "mp3",
55
+ speed: float = 1.0,
56
+ stream: bool = False,
55
57
  ):
58
+ import ChatTTS
56
59
  import numpy as np
57
60
  import torch
58
61
  import torchaudio
@@ -71,14 +74,43 @@ class ChatTTSModel:
71
74
 
72
75
  default = 5
73
76
  infer_speed = int(default * speed)
74
- params_infer_code = {"spk_emb": rnd_spk_emb, "prompt": f"[speed_{infer_speed}]"}
77
+ params_infer_code = ChatTTS.Chat.InferCodeParams(
78
+ prompt=f"[speed_{infer_speed}]", spk_emb=rnd_spk_emb
79
+ )
75
80
 
76
81
  assert self._model is not None
77
- wavs = self._model.infer([input], params_infer_code=params_infer_code)
78
-
79
- # Save the generated audio
80
- with BytesIO() as out:
81
- torchaudio.save(
82
- out, torch.from_numpy(wavs[0]), 24000, format=response_format
82
+ if stream:
83
+ iter = self._model.infer(
84
+ [input], params_infer_code=params_infer_code, stream=True
83
85
  )
84
- return out.getvalue()
86
+
87
+ def _generator():
88
+ with BytesIO() as out:
89
+ writer = torchaudio.io.StreamWriter(out, format=response_format)
90
+ writer.add_audio_stream(sample_rate=24000, num_channels=1)
91
+ i = 0
92
+ last_pos = 0
93
+ with writer.open():
94
+ for it in iter:
95
+ for itt in it:
96
+ for chunk in itt:
97
+ chunk = np.array([chunk]).transpose()
98
+ writer.write_audio_chunk(i, torch.from_numpy(chunk))
99
+ new_last_pos = out.tell()
100
+ if new_last_pos != last_pos:
101
+ out.seek(last_pos)
102
+ encoded_bytes = out.read()
103
+ print(len(encoded_bytes))
104
+ yield encoded_bytes
105
+ last_pos = new_last_pos
106
+
107
+ return _generator()
108
+ else:
109
+ wavs = self._model.infer([input], params_infer_code=params_infer_code)
110
+
111
+ # Save the generated audio
112
+ with BytesIO() as out:
113
+ torchaudio.save(
114
+ out, torch.from_numpy(wavs[0]), 24000, format=response_format
115
+ )
116
+ return out.getvalue()