xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl

Files changed (108)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +2 -1
  3. xinference/core/model.py +8 -4
  4. xinference/core/supervisor.py +2 -3
  5. xinference/core/worker.py +7 -5
  6. xinference/deploy/cmdline.py +2 -0
  7. xinference/deploy/local.py +5 -0
  8. xinference/deploy/test/test_cmdline.py +1 -1
  9. xinference/deploy/worker.py +6 -0
  10. xinference/model/audio/cosyvoice.py +0 -1
  11. xinference/model/audio/model_spec.json +44 -20
  12. xinference/model/core.py +3 -0
  13. xinference/model/embedding/flag/core.py +5 -0
  14. xinference/model/embedding/llama_cpp/core.py +22 -19
  15. xinference/model/embedding/sentence_transformers/core.py +18 -4
  16. xinference/model/embedding/vllm/core.py +36 -9
  17. xinference/model/image/cache_manager.py +56 -0
  18. xinference/model/image/core.py +9 -0
  19. xinference/model/image/model_spec.json +178 -1
  20. xinference/model/image/stable_diffusion/core.py +155 -23
  21. xinference/model/llm/cache_manager.py +17 -3
  22. xinference/model/llm/harmony.py +245 -0
  23. xinference/model/llm/llama_cpp/core.py +41 -40
  24. xinference/model/llm/llm_family.json +688 -11
  25. xinference/model/llm/llm_family.py +1 -1
  26. xinference/model/llm/sglang/core.py +108 -5
  27. xinference/model/llm/transformers/core.py +20 -18
  28. xinference/model/llm/transformers/gemma3.py +1 -1
  29. xinference/model/llm/transformers/gpt_oss.py +91 -0
  30. xinference/model/llm/transformers/multimodal/core.py +1 -1
  31. xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
  32. xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
  33. xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
  34. xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
  35. xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
  36. xinference/model/llm/transformers/utils.py +1 -33
  37. xinference/model/llm/utils.py +61 -7
  38. xinference/model/llm/vllm/core.py +44 -8
  39. xinference/model/rerank/__init__.py +66 -23
  40. xinference/model/rerank/cache_manager.py +35 -0
  41. xinference/model/rerank/core.py +87 -339
  42. xinference/model/rerank/custom.py +33 -8
  43. xinference/model/rerank/model_spec.json +251 -212
  44. xinference/model/rerank/rerank_family.py +137 -0
  45. xinference/model/rerank/sentence_transformers/__init__.py +13 -0
  46. xinference/model/rerank/sentence_transformers/core.py +337 -0
  47. xinference/model/rerank/vllm/__init__.py +13 -0
  48. xinference/model/rerank/vllm/core.py +156 -0
  49. xinference/model/utils.py +108 -0
  50. xinference/model/video/model_spec.json +95 -1
  51. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  52. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  53. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  54. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  55. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  56. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  57. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  58. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  59. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  61. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  63. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  64. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  65. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  66. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  67. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  69. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  70. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  71. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  72. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  73. xinference/types.py +2 -0
  74. xinference/ui/gradio/chat_interface.py +2 -0
  75. xinference/ui/gradio/media_interface.py +353 -7
  76. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  77. xinference/ui/web/ui/build/index.html +1 -1
  78. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  79. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  80. xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
  81. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  82. xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
  83. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
  84. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  85. xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
  86. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  87. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  88. xinference/ui/web/ui/src/locales/en.json +2 -0
  89. xinference/ui/web/ui/src/locales/ja.json +2 -0
  90. xinference/ui/web/ui/src/locales/ko.json +2 -0
  91. xinference/ui/web/ui/src/locales/zh.json +2 -0
  92. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
  93. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
  94. xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
  95. xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
  96. xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
  97. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  98. xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
  99. xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
  100. xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
  101. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  102. xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
  103. xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
  104. /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  105. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
  106. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
  107. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
  108. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2025-08-03T12:12:02+0800",
+ "date": "2025-08-30T03:57:39+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "2adec1c027044a920632ad7626f8f278eef83361",
- "version": "1.8.1.rc1"
+ "full-revisionid": "b2d793d0b4a0af632932eb63dbeb1bc91b5b3d74",
+ "version": "1.9.1"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -2249,8 +2249,9 @@ class RESTfulAPI(CancelMixin):
             )
         if body.tools and body.stream:
             is_vllm = await model.is_vllm_backend()
+            is_sglang = await model.is_sglang_backend()
             if not (
-                (is_vllm and model_family in QWEN_TOOL_CALL_FAMILY)
+                ((is_vllm or is_sglang) and model_family in QWEN_TOOL_CALL_FAMILY)
                 or (not is_vllm and model_family in GLM4_TOOL_CALL_FAMILY)
             ):
                 raise HTTPException(
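
With this change, Qwen-family models served on the SGLang backend can stream tool-call deltas just as on vLLM. A minimal client-side sketch through the OpenAI-compatible endpoint; the base URL, model uid, and tool schema below are illustrative assumptions, not taken from this diff:

    # Streaming tool calls against a Qwen-family model on vLLM or SGLang.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:9997/v1", api_key="not-used")
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool for illustration
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }]
    stream = client.chat.completions.create(
        model="qwen2.5-instruct",  # assumed uid of a QWEN_TOOL_CALL_FAMILY model
        messages=[{"role": "user", "content": "What's the weather in Paris?"}],
        tools=tools,
        stream=True,
    )
    for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.tool_calls:
            print(chunk.choices[0].delta.tool_calls)
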
xinference/core/model.py CHANGED
@@ -365,6 +365,11 @@ class ModelActor(xo.StatelessActor, CancelMixin):
 
         return isinstance(self._model, VLLMModel)
 
+    def is_sglang_backend(self) -> bool:
+        from ..model.llm.sglang.core import SGLANGModel
+
+        return isinstance(self._model, SGLANGModel)
+
     async def load(self):
         try:
             # Change process title for model
@@ -877,10 +882,9 @@ class ModelActor(xo.StatelessActor, CancelMixin):
         **kwargs,
     ):
         if hasattr(self._model, "text_to_image"):
-            # Directly delegate to model, let model decide how to handle (batching or not)
-            progressor = kwargs["progressor"] = await self._get_progressor(
-                kwargs.pop("request_id", None)
-            )
+            # Get progressor (don't pop request_id, let _call_wrapper handle cancellation)
+            request_id = kwargs.get("request_id")
+            progressor = kwargs["progressor"] = await self._get_progressor(request_id)  # type: ignore
             with progressor:
                 return await self._call_wrapper_json(
                     self._model.text_to_image,
xinference/core/supervisor.py CHANGED
@@ -476,7 +476,7 @@ class SupervisorActor(xo.StatelessActor):
     async def _to_rerank_model_reg(
         self, model_spec: "RerankModelFamilyV2", is_builtin: bool
     ) -> Dict[str, Any]:
-        from ..model.cache_manager import CacheManager
+        from ..model.rerank.cache_manager import RerankCacheManager as CacheManager
 
         instance_cnt = await self.get_instance_count(model_spec.model_name)
         version_cnt = await self.get_model_version_count(model_spec.model_name)
@@ -712,9 +712,8 @@ class SupervisorActor(xo.StatelessActor):
         from ..model.rerank import BUILTIN_RERANK_MODELS
         from ..model.rerank.custom import get_user_defined_reranks
 
-        for model_name, families in BUILTIN_RERANK_MODELS.items():
+        for model_name, family in BUILTIN_RERANK_MODELS.items():
             if detailed:
-                family = [x for x in families if x.model_hub == "huggingface"][0]
                 ret.append(await self._to_rerank_model_reg(family, is_builtin=True))
             else:
                 ret.append({"model_name": model_name, "is_builtin": True})
xinference/core/worker.py CHANGED
@@ -817,10 +817,7 @@ class WorkerActor(xo.StatelessActor):
         # we specify python_path explicitly
         # sometimes uv would find other versions.
         python_path = pathlib.Path(sys.executable)
-        kw = {}
-        if XINFERENCE_VIRTUAL_ENV_SKIP_INSTALLED:
-            kw["skip_installed"] = XINFERENCE_VIRTUAL_ENV_SKIP_INSTALLED
-        virtual_env_manager.create_env(python_path=python_path, **kw)
+        virtual_env_manager.create_env(python_path=python_path)
         return virtual_env_manager
 
     @classmethod
@@ -830,10 +827,13 @@ class WorkerActor(xo.StatelessActor):
         settings: Optional[VirtualEnvSettings],
         virtual_env_packages: Optional[List[str]],
     ):
-        if not settings or not settings.packages:
+        if (not settings or not settings.packages) and not virtual_env_packages:
             # no settings or no packages
             return
 
+        if settings is None:
+            settings = VirtualEnvSettings(packages=virtual_env_packages)
+
         if settings.inherit_pip_config:
             # inherit pip config
             pip_config = get_pip_config_args()
@@ -847,6 +847,8 @@ class WorkerActor(xo.StatelessActor):
             packages.extend(virtual_env_packages)
         conf.pop("packages", None)
         conf.pop("inherit_pip_config", None)
+        if XINFERENCE_VIRTUAL_ENV_SKIP_INSTALLED:
+            conf["skip_installed"] = XINFERENCE_VIRTUAL_ENV_SKIP_INSTALLED
 
         logger.info(
             "Installing packages %s in virtual env %s, with settings(%s)",
xinference/deploy/cmdline.py CHANGED
@@ -1345,6 +1345,8 @@ def model_chat(
         messages,
         generate_config={"stream": stream, "max_tokens": max_tokens},
     ):
+        if not chunk["choices"]:
+            continue
        delta = chunk["choices"][0]["delta"]
        if "content" not in delta:
            continue
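
The new guard matters because a streamed response can contain chunks whose "choices" list is empty (for example, a trailing usage-only chunk). A minimal sketch of the same defensive pattern for any consumer of the chat stream; `chunks` stands in for the iterator the loop above walks:

    # Collect streamed text while skipping chunks that carry no choices.
    def collect_text(chunks) -> str:
        parts = []
        for chunk in chunks:
            if not chunk["choices"]:  # e.g. a usage-only trailing chunk
                continue
            delta = chunk["choices"][0]["delta"]
            if delta.get("content"):
                parts.append(delta["content"])
        return "".join(parts)
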
xinference/deploy/local.py CHANGED
@@ -152,6 +152,11 @@ def main(
     logging_conf: Optional[Dict] = None,
     auth_config_file: Optional[str] = None,
 ):
+    # force to set spawn,
+    # cuda may be inited in xoscar virtualenv
+    # which will raise error after sub pool is created
+    multiprocessing.set_start_method("spawn")
+
     supervisor_address = f"{host}:{get_next_port()}"
     local_cluster = run_in_subprocess(
         supervisor_address, metrics_exporter_host, metrics_exporter_port, logging_conf
xinference/deploy/test/test_cmdline.py CHANGED
@@ -87,7 +87,7 @@ def test_cmdline(setup, stream, model_uid):
         ],
     )
     assert result.exit_code == 0
-    assert model_uid in result.stdout
+    assert model_uid in result.output
 
     # model generate
     result = runner.invoke(
xinference/deploy/worker.py CHANGED
@@ -14,6 +14,7 @@
 
 import asyncio
 import logging
+import multiprocessing
 import os
 from typing import Any, Optional
 
@@ -81,6 +82,11 @@ def main(
     metrics_exporter_port: Optional[int] = None,
     logging_conf: Optional[dict] = None,
 ):
+    # force to set spawn,
+    # cuda may be inited in xoscar virtualenv
+    # which will raise error after sub pool is created
+    multiprocessing.set_start_method("spawn")
+
     loop = asyncio.get_event_loop()
     task = loop.create_task(
         _start_worker(
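
One caveat of the stdlib call used in both entry points: multiprocessing.set_start_method raises RuntimeError if a start method has already been fixed in the process, so host programs embedding these `main` functions may want the tolerant form. A small sketch of that standard-library behavior (not code from this diff):

    import multiprocessing

    # A second call to set_start_method raises RuntimeError unless forced.
    try:
        multiprocessing.set_start_method("spawn")
    except RuntimeError:
        pass  # start method already set by the embedding process

    # Equivalent one-liner: multiprocessing.set_start_method("spawn", force=True)
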
xinference/model/audio/cosyvoice.py CHANGED
@@ -60,7 +60,6 @@ class CosyVoiceModel:
             from cosyvoice.cli.cosyvoice import CosyVoice2 as CosyVoice
 
             self._is_cosyvoice2 = True
-            kwargs = {"use_flow_cache": self._kwargs.get("use_flow_cache", False)}
         else:
             from cosyvoice.cli.cosyvoice import CosyVoice
 
xinference/model/audio/model_spec.json CHANGED
@@ -525,7 +525,8 @@
     "model_name": "ChatTTS",
     "model_family": "ChatTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "virtualenv": {
@@ -551,7 +552,8 @@
     "model_name": "CosyVoice-300M",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -570,7 +572,8 @@
     "model_name": "CosyVoice-300M-SFT",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -589,7 +592,8 @@
     "model_name": "CosyVoice-300M-Instruct",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -608,7 +612,9 @@
     "model_name": "CosyVoice2-0.5B",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "virtualenv": {
@@ -625,7 +631,8 @@
       "HyperPyYAML",
       "onnxruntime>=1.16.0",
       "pyworld>=0.3.4",
-      "WeTextProcessing<1.0.4",
+      "wetext==0.0.9",
+      "transformers==4.51.3",
       "#system_numpy#",
       "#system_torch#"
     ]
@@ -646,7 +653,9 @@
     "model_name": "FishSpeech-1.5",
     "model_family": "FishAudio",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -665,7 +674,9 @@
     "model_name": "F5-TTS",
     "model_family": "F5-TTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -684,7 +695,9 @@
     "model_name": "F5-TTS-MLX",
     "model_family": "F5-TTS-MLX",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -699,7 +712,8 @@
     "model_name": "MeloTTS-English",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "EN",
@@ -715,7 +729,8 @@
     "model_name": "MeloTTS-English-v2",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "EN",
@@ -731,7 +746,8 @@
     "model_name": "MeloTTS-English-v3",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "EN",
@@ -747,7 +763,8 @@
     "model_name": "MeloTTS-French",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "FR",
@@ -763,7 +780,8 @@
     "model_name": "MeloTTS-Japanese",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "JP",
@@ -779,7 +797,8 @@
     "model_name": "MeloTTS-Spanish",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "ES",
@@ -795,7 +814,8 @@
     "model_name": "MeloTTS-Chinese",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "ZH",
@@ -811,7 +831,8 @@
     "model_name": "MeloTTS-Korean",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "KR",
@@ -827,7 +848,8 @@
     "model_name": "Kokoro-82M",
     "model_family": "Kokoro",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -846,7 +868,8 @@
     "model_name": "Kokoro-82M-MLX",
     "model_family": "Kokoro-MLX",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -874,7 +897,8 @@
     "model_name": "MegaTTS3",
     "model_family": "MegaTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
xinference/model/core.py CHANGED
@@ -81,6 +81,9 @@ def create_model_instance(
         return create_rerank_model_instance(
             model_uid,
             model_name,
+            model_engine,
+            model_format,
+            quantization,
             download_hub,
             model_path,
             **kwargs,
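
With model_engine, model_format, and quantization now forwarded to create_rerank_model_instance, rerank models can be launched against a chosen engine much like LLMs. A hedged sketch using the Python client; the endpoint, model name, and engine value are assumptions for illustration, not taken from this diff:

    from xinference.client import Client

    client = Client("http://localhost:9997")   # assumed local endpoint
    uid = client.launch_model(
        model_name="bge-reranker-v2-m3",       # illustrative rerank model
        model_type="rerank",
        model_engine="sentence_transformers",  # assumed engine name
    )
    model = client.get_model(uid)
    print(model.rerank(["doc a", "doc b"], "a query"))
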
xinference/model/embedding/flag/core.py CHANGED
@@ -58,6 +58,11 @@ class FlagEmbeddingModel(EmbeddingModel):
         self._return_sparse = return_sparse
 
     def load(self):
+        # add truncate_dim args hint
+        if self._kwargs and "dimensions" in self._kwargs:
+            raise NotImplementedError(
+                "Flag embedder does not support dimensions argument now."
+            )
         try:
             from FlagEmbedding import BGEM3FlagModel
         except ImportError:
xinference/model/embedding/llama_cpp/core.py CHANGED
@@ -22,7 +22,7 @@ import queue
 import sys
 from typing import List, Optional, Union
 
-import orjson
+from packaging import version
 
 from ....types import Embedding
 from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1
@@ -69,15 +69,29 @@ class XllamaCppEmbeddingModel(EmbeddingModel):
         return sys.platform.startswith("linux")
 
     def load(self):
+        # add truncate_dim args hint
+        if "dimensions" in self._kwargs:
+            raise NotImplementedError(
+                "LlamaCpp embedder does not support dimensions argument now."
+            )
         try:
             from xllamacpp import (
                 CommonParams,
                 Server,
+                __version__,
                 estimate_gpu_layers,
                 get_device_info,
                 ggml_backend_dev_type,
                 llama_pooling_type,
             )
+
+            try:
+                if version.parse(__version__) < version.parse("0.2.0"):
+                    raise RuntimeError(
+                        "Please update xllamacpp to >= 0.2.0 by `pip install -U xllamacpp`"
+                    )
+            except version.InvalidVersion:
+                pass  # If the version parse failed, we just skip the version check.
         except ImportError:
             error_message = "Failed to import module 'xllamacpp'"
             installation_guide = ["Please make sure 'xllamacpp' is installed. "]
@@ -162,7 +176,8 @@ class XllamaCppEmbeddingModel(EmbeddingModel):
             )
             logger.info("Estimate num gpu layers: %s", estimate)
             if estimate.tensor_split:
-                params.tensor_split = estimate.tensor_split
+                for i in range(len(estimate.tensor_split)):
+                    params.tensor_split[i] = estimate.tensor_split[i]
             else:
                 params.n_gpu_layers = estimate.layers
         except Exception as e:
@@ -190,24 +205,12 @@ class XllamaCppEmbeddingModel(EmbeddingModel):
         model_uid: Optional[str] = kwargs.pop("model_uid", None)
         if model_uid:
             data["model"] = model_uid
-        prompt_json = orjson.dumps(data)
-
-        def _error_callback(err):
-            try:
-                msg = orjson.loads(err)
-                q.put(_Error(msg))
-            except Exception as e:
-                q.put(_Error(str(e)))
-
-        def _ok_callback(ok):
-            try:
-                res = orjson.loads(ok)
-                q.put(res)
-            except Exception as e:
-                q.put(_Error(str(e)))
-
         try:
-            self._llm.handle_embeddings(prompt_json, _error_callback, _ok_callback)
+            res = self._llm.handle_embeddings(data)
+            if res.get("code"):
+                q.put(_Error(res))
+            else:
+                q.put(res)
         except Exception as ex:
             q.put(_Error(str(ex)))
         q.put(_Done)
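
The version gate above deliberately tolerates versions that packaging cannot parse. The same pattern as a standalone sketch, with hypothetical names:

    from packaging import version

    def require_at_least(installed: str, minimum: str, hint: str) -> None:
        # Raise if `installed` parses below `minimum`; skip unparsable versions.
        try:
            if version.parse(installed) < version.parse(minimum):
                raise RuntimeError(hint)
        except version.InvalidVersion:
            pass  # e.g. local dev builds; skip the check rather than fail

    require_at_least("0.2.1", "0.2.0", "Please update xllamacpp to >= 0.2.0")  # passes
    # require_at_least("0.1.14", "0.2.0", ...) would raise RuntimeError
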
xinference/model/embedding/sentence_transformers/core.py CHANGED
@@ -19,8 +19,8 @@ from typing import List, Optional, Union, no_type_check
 import numpy as np
 import torch
 
-from ....device_utils import is_device_available
 from ....types import Embedding, EmbeddingData, EmbeddingUsage
+from ...utils import is_flash_attn_available
 from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1
 
 logger = logging.getLogger(__name__)
@@ -71,6 +71,12 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
             )
             torch_dtype = torch.float32
 
+        dimensions = self._kwargs.get("dimensions")
+        assert dimensions is None or isinstance(dimensions, int), (
+            "The `dimensions` argument must be an integer, "
+            f"but got {type(dimensions)}: {dimensions}"
+        )
+
         if (
             "gte" in self.model_family.model_name.lower()
             and "qwen2" in self.model_family.model_name.lower()
@@ -82,16 +88,16 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 self._model_path,
                 device=self._device,
                 model_kwargs=model_kwargs,
+                truncate_dim=dimensions,
             )
         elif "qwen3" in self.model_family.model_name.lower():
             # qwen3 embedding
-            flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
             flash_attn_enabled = self._kwargs.get(
-                "enable_flash_attn", is_device_available("cuda")
+                "enable_flash_attn", is_flash_attn_available()
             )
             model_kwargs = {"device_map": "auto"}
             tokenizer_kwargs = {}
-            if flash_attn_installed and flash_attn_enabled:
+            if flash_attn_enabled:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_kwargs["torch_dtype"] = "bfloat16"
                 tokenizer_kwargs["padding_side"] = "left"
@@ -107,6 +113,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 device=self._device,
                 model_kwargs=model_kwargs,
                 tokenizer_kwargs=tokenizer_kwargs,
+                truncate_dim=dimensions,
             )
         else:
             model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
@@ -115,6 +122,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 device=self._device,
                 model_kwargs=model_kwargs,
                 trust_remote_code=True,
+                truncate_dim=dimensions,
             )
 
         if hasattr(self._model, "tokenizer"):
@@ -271,6 +279,12 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
         with torch.no_grad():
             out_features = model.forward(features, **kwargs)
 
+            from sentence_transformers.util import truncate_embeddings
+
+            out_features["sentence_embedding"] = truncate_embeddings(
+                out_features["sentence_embedding"], model.truncate_dim
+            )
+
         if output_value == "token_embeddings":
             embeddings = []
             for token_emb, attention in zip(
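
truncate_dim and truncate_embeddings are the sentence-transformers (>= 2.7) hooks behind the Matryoshka-style `dimensions` support added above; truncate_embeddings(x, None) is a no-op, which is why the patched forward path can pass model.truncate_dim unconditionally. A standalone sketch; the model name is an assumption:

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import truncate_embeddings

    # Load with a reduced output dimensionality (Matryoshka-style truncation).
    model = SentenceTransformer("BAAI/bge-m3", truncate_dim=256)  # illustrative model
    emb = model.encode(["hello world"])
    print(emb.shape)  # (1, 256)

    # Further truncation after the fact, as the patched forward pass does.
    full = model.encode(["hello world"], convert_to_tensor=True)
    print(truncate_embeddings(full, 128).shape)  # torch.Size([1, 128])
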
xinference/model/embedding/vllm/core.py CHANGED
@@ -17,6 +17,7 @@ import logging
 from typing import List, Union
 
 from ....types import Embedding, EmbeddingData, EmbeddingUsage
+from ...utils import cache_clean
 from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1
 
 logger = logging.getLogger(__name__)
@@ -24,7 +25,6 @@ SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "Qwen3"]
 
 
 class VLLMEmbeddingModel(EmbeddingModel):
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._context_length = None
@@ -41,28 +41,42 @@ class VLLMEmbeddingModel(EmbeddingModel):
             ]
 
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+        if self.model_family.model_name in {
+            "Qwen3-Embedding-0.6B",
+            "Qwen3-Embedding-4B",
+            "Qwen3-Embedding-8B",
+        }:
+            if "hf_overrides" not in self._kwargs:
+                self._kwargs["hf_overrides"] = {
+                    "is_matryoshka": True,
+                }
+            elif isinstance(self._kwargs["hf_overrides"], dict):
+                self._kwargs["hf_overrides"].update(
+                    is_matryoshka=True,
+                )
 
-        self._model = LLM(model=self._model_path, task="embed")
+        self._model = LLM(model=self._model_path, task="embed", **self._kwargs)
         self._tokenizer = self._model.get_tokenizer()
 
     @staticmethod
     def _get_detailed_instruct(task_description: str, query: str) -> str:
         return f"Instruct: {task_description}\nQuery:{query}"
 
+    @cache_clean
     def create_embedding(
         self,
         sentences: Union[str, List[str]],
         **kwargs,
     ):
+        from packaging.version import Version
+        from vllm import PoolingParams
+        from vllm import __version__ as vllm_version
+
         sentences = self._fix_langchain_openai_inputs(sentences)
         model_uid = kwargs.pop("model_uid", None)
 
         normalize_embedding = kwargs.get("normalize_embedding", True)
-        if not normalize_embedding:
-            raise ValueError(
-                "vllm embedding engine does not support "
-                "setting `normalize_embedding=False`"
-            )
+        dimensions = kwargs.get("dimensions", None)
 
         assert self._model is not None
 
@@ -91,8 +105,21 @@ class VLLMEmbeddingModel(EmbeddingModel):
             sentences = truncated_sentences[0]
         else:
             sentences = truncated_sentences
-
-        outputs = self._model.embed(sentences, use_tqdm=False)
+        if Version(vllm_version) > Version("0.10.1"):
+            pool_params = PoolingParams(
+                dimensions=dimensions, normalize=normalize_embedding
+            )
+        else:
+            if not normalize_embedding:
+                raise ValueError(
+                    f"vLLM version {vllm_version} does not support "
+                    f"unnormalized embeddings. "
+                    f"Please upgrade to v0.10.1 or later."
+                )
+            pool_params = PoolingParams(dimensions=dimensions)
+        outputs = self._model.embed(
+            sentences, use_tqdm=False, pooling_params=pool_params
+        )
         embedding_list = []
         all_token_nums = 0
         for index, output in enumerate(outputs):
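
End to end, these changes let `dimensions` and `normalize_embedding` flow from the request into vLLM's PoolingParams (on vLLM newer than 0.10.1). A hedged usage sketch through the Python client; the endpoint and model uid are assumptions:

    from xinference.client import Client

    client = Client("http://localhost:9997")          # assumed local endpoint
    model = client.get_model("Qwen3-Embedding-0.6B")  # assumed launched model uid

    # `dimensions` rides through kwargs into PoolingParams; the is_matryoshka
    # override above is what makes truncated dimensions legal for Qwen3.
    res = model.create_embedding("what is xinference?", dimensions=128)
    print(len(res["data"][0]["embedding"]))  # 128
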