xinference 1.9.0__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (92)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +415 -1
  3. xinference/constants.py +2 -0
  4. xinference/core/model.py +3 -4
  5. xinference/core/supervisor.py +29 -1
  6. xinference/core/worker.py +4 -1
  7. xinference/deploy/cmdline.py +2 -0
  8. xinference/deploy/test/test_cmdline.py +1 -1
  9. xinference/model/audio/core.py +5 -0
  10. xinference/model/audio/cosyvoice.py +0 -1
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +64 -20
  14. xinference/model/embedding/flag/core.py +5 -0
  15. xinference/model/embedding/llama_cpp/core.py +22 -19
  16. xinference/model/embedding/sentence_transformers/core.py +19 -4
  17. xinference/model/embedding/vllm/core.py +40 -8
  18. xinference/model/image/cache_manager.py +56 -0
  19. xinference/model/image/core.py +9 -0
  20. xinference/model/image/model_spec.json +116 -9
  21. xinference/model/image/stable_diffusion/core.py +141 -31
  22. xinference/model/llm/core.py +10 -0
  23. xinference/model/llm/llama_cpp/core.py +42 -40
  24. xinference/model/llm/llm_family.json +435 -23
  25. xinference/model/llm/llm_family.py +1 -0
  26. xinference/model/llm/mlx/core.py +52 -33
  27. xinference/model/llm/sglang/core.py +2 -44
  28. xinference/model/llm/tool_parsers/__init__.py +58 -0
  29. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  30. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +128 -0
  31. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  32. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  33. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  34. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  35. xinference/model/llm/transformers/core.py +6 -12
  36. xinference/model/llm/utils.py +128 -46
  37. xinference/model/llm/vllm/core.py +8 -61
  38. xinference/model/rerank/core.py +3 -0
  39. xinference/model/rerank/sentence_transformers/core.py +1 -1
  40. xinference/model/rerank/vllm/core.py +56 -6
  41. xinference/model/utils.py +1 -2
  42. xinference/model/video/model_spec.json +95 -1
  43. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  44. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  45. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  46. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  47. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  48. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  49. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  50. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  51. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  52. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  53. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  54. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  55. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  56. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  57. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  58. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  59. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  60. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  61. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  62. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  63. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  64. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  65. xinference/types.py +105 -2
  66. xinference/ui/gradio/chat_interface.py +2 -0
  67. xinference/ui/gradio/media_interface.py +353 -7
  68. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  69. xinference/ui/web/ui/build/index.html +1 -1
  70. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  71. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  72. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  73. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  74. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  75. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  76. xinference/ui/web/ui/src/locales/en.json +2 -0
  77. xinference/ui/web/ui/src/locales/ja.json +2 -0
  78. xinference/ui/web/ui/src/locales/ko.json +2 -0
  79. xinference/ui/web/ui/src/locales/zh.json +2 -0
  80. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/METADATA +16 -12
  81. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/RECORD +86 -77
  82. xinference/ui/web/ui/build/static/js/main.4918643a.js +0 -3
  83. xinference/ui/web/ui/build/static/js/main.4918643a.js.map +0 -1
  84. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  85. xinference/ui/web/ui/node_modules/.cache/babel-loader/89179f8f51887b9167721860a12412549ff04f78162e921a7b6aa6532646deb2.json +0 -1
  86. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  87. xinference/ui/web/ui/node_modules/.cache/babel-loader/9dc5cfc67dd0617b0272aeef8651f1589b2155a4ff1fd72ad3166b217089b619.json +0 -1
  88. /xinference/ui/web/ui/build/static/js/{main.4918643a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  89. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/WHEEL +0 -0
  90. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/entry_points.txt +0 -0
  91. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/licenses/LICENSE +0 -0
  92. {xinference-1.9.0.dist-info → xinference-1.10.0.dist-info}/top_level.txt +0 -0
xinference/model/audio/cosyvoice.py
@@ -60,7 +60,6 @@ class CosyVoiceModel:
             from cosyvoice.cli.cosyvoice import CosyVoice2 as CosyVoice

             self._is_cosyvoice2 = True
-            kwargs = {"use_flow_cache": self._kwargs.get("use_flow_cache", False)}
         else:
             from cosyvoice.cli.cosyvoice import CosyVoice

xinference/model/audio/kokoro.py
@@ -81,7 +81,7 @@ class KokoroModel:
         logger.info("Launching Kokoro model with language code: %s", lang_code)
         self._model = KPipeline(
             lang_code=lang_code,
-            model=KModel(config=config_path, model=model_path),
+            model=KModel(config=config_path, model=model_path).to(self._device),
             device=self._device,
         )

xinference/model/audio/kokoro_zh.py (new file)
@@ -0,0 +1,124 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+import numpy as np
+
+from ...device_utils import get_available_device, is_device_available
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV2
+
+logger = logging.getLogger(__name__)
+
+REPO_ID = "hexgrad/Kokoro-82M-v1.1-zh"
+
+
+class KokoroZHModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV2",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self.model_family = model_spec
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+        self._en_pipeline = None
+
+    def _en_callable(self, text):
+        """
+        Fixing the issue of English words being skipped in the Chinese model.
+        from https://hf-mirror.com/hexgrad/Kokoro-82M-v1.1-zh/blob/main/samples/make_zh.py
+        """
+        if text == "Kokoro":
+            return "kˈOkəɹO"
+        elif text == "Sol":
+            return "sˈOl"
+        return next(self._en_pipeline(text)).phonemes
+
+    @property
+    def model_ability(self):
+        return self._model_spec.model_ability
+
+    def load(self):
+        if self._device is None:
+            self._device = get_available_device()
+        else:
+            if not is_device_available(self._device):
+                raise ValueError(f"Device {self._device} is not available!")
+
+        import os
+
+        from kokoro import KModel, KPipeline
+
+        self._en_pipeline = KPipeline(lang_code="a", repo_id=REPO_ID, model=False)
+
+        config_path = os.path.join(self._model_path, "config.json")
+        model_path = os.path.join(self._model_path, "kokoro-v1_1-zh.pth")
+        lang_code = self._kwargs.get("lang_code", "z")
+        logger.info("Launching Kokoro model with language code: %s", lang_code)
+
+        self._model = KPipeline(
+            lang_code=lang_code,
+            model=KModel(config=config_path, model=model_path).to(self._device),
+            repo_id=REPO_ID,
+            en_callable=self._en_callable,
+            device=self._device,
+        )
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        import soundfile
+
+        if stream:
+            raise Exception("Kokoro does not support stream mode.")
+        assert self._model is not None
+        if not voice:
+            voice = "zf_001"
+            logger.info("Auto select speaker: %s", voice)
+        elif voice.endswith(".pt"):
+            logger.info("Using custom voice pt: %s", voice)
+        else:
+            logger.info("Using voice: %s", voice)
+        logger.info("Speech kwargs: %s", kwargs)
+        generator = self._model(text=input, voice=voice, speed=speed, **kwargs)
+        results = list(generator)
+        audio = np.concatenate([r[2] for r in results])
+        # Save the generated audio
+        with BytesIO() as out:
+            with soundfile.SoundFile(
+                out,
+                "w",
+                24000,
+                1,
+                format=response_format.upper(),
+            ) as f:
+                f.write(audio)
+            return out.getvalue()
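
For orientation, a hedged sketch of driving this model through xinference's standard client once it is registered (the model name matches the new spec entry below; the Client and speech handle follow xinference's existing audio-model API):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    uid = client.launch_model(model_name="Kokoro-82M-v1.1-zh", model_type="audio")
    model = client.get_model(uid)

    # voice falls back to "zf_001" when empty, per KokoroZHModel.speech above
    audio = model.speech("你好，世界。Hello, world.", voice="zf_001")
    with open("hello.mp3", "wb") as f:
        f.write(audio)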
xinference/model/audio/model_spec.json
@@ -525,7 +525,8 @@
     "model_name": "ChatTTS",
     "model_family": "ChatTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "virtualenv": {
@@ -551,7 +552,8 @@
     "model_name": "CosyVoice-300M",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -570,7 +572,8 @@
     "model_name": "CosyVoice-300M-SFT",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -589,7 +592,8 @@
     "model_name": "CosyVoice-300M-Instruct",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -608,7 +612,9 @@
     "model_name": "CosyVoice2-0.5B",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "virtualenv": {
@@ -625,7 +631,8 @@
       "HyperPyYAML",
       "onnxruntime>=1.16.0",
       "pyworld>=0.3.4",
-      "WeTextProcessing<1.0.4",
+      "wetext==0.0.9",
+      "transformers==4.51.3",
       "#system_numpy#",
       "#system_torch#"
     ]
@@ -646,7 +653,9 @@
     "model_name": "FishSpeech-1.5",
     "model_family": "FishAudio",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -665,7 +674,9 @@
     "model_name": "F5-TTS",
     "model_family": "F5-TTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -684,7 +695,9 @@
     "model_name": "F5-TTS-MLX",
     "model_family": "F5-TTS-MLX",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -699,7 +712,8 @@
     "model_name": "MeloTTS-English",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "EN",
@@ -715,7 +729,8 @@
     "model_name": "MeloTTS-English-v2",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "EN",
@@ -731,7 +746,8 @@
     "model_name": "MeloTTS-English-v3",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "EN",
@@ -747,7 +763,8 @@
     "model_name": "MeloTTS-French",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "FR",
@@ -763,7 +780,8 @@
     "model_name": "MeloTTS-Japanese",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
    ],
     "multilingual": false,
     "language": "JP",
@@ -779,7 +797,8 @@
     "model_name": "MeloTTS-Spanish",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "ES",
@@ -795,7 +814,8 @@
     "model_name": "MeloTTS-Chinese",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "ZH",
@@ -811,7 +831,8 @@
     "model_name": "MeloTTS-Korean",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "KR",
@@ -827,7 +848,8 @@
     "model_name": "Kokoro-82M",
     "model_family": "Kokoro",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -840,13 +862,34 @@
         "model_revision": "master"
       }
     }
+  },
+  {
+    "version": 2,
+    "model_name": "Kokoro-82M-v1.1-zh",
+    "model_family": "Kokoro-zh",
+    "model_ability": [
+      "text2audio",
+      "text2audio_zero_shot"
+    ],
+    "multilingual": false,
+    "model_src": {
+      "huggingface": {
+        "model_id": "hexgrad/Kokoro-82M-v1.1-zh",
+        "model_revision": "01e7505bd6a7a2ac4975463114c3a7650a9f7218"
+      },
+      "modelscope": {
+        "model_id": "AI-ModelScope/Kokoro-82M-v1.1-zh",
+        "model_revision": "master"
+      }
+    }
   },
   {
     "version": 2,
     "model_name": "Kokoro-82M-MLX",
     "model_family": "Kokoro-MLX",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -874,7 +917,8 @@
     "model_name": "MegaTTS3",
     "model_family": "MegaTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
xinference/model/embedding/flag/core.py
@@ -58,6 +58,11 @@ class FlagEmbeddingModel(EmbeddingModel):
         self._return_sparse = return_sparse

     def load(self):
+        # add truncate_dim args hint
+        if self._kwargs and "dimensions" in self._kwargs:
+            raise NotImplementedError(
+                "Flag embedder does not support dimensions argument now."
+            )
         try:
             from FlagEmbedding import BGEM3FlagModel
         except ImportError:
xinference/model/embedding/llama_cpp/core.py
@@ -22,7 +22,7 @@ import queue
 import sys
 from typing import List, Optional, Union

-import orjson
+from packaging import version

 from ....types import Embedding
 from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1
@@ -69,15 +69,29 @@ class XllamaCppEmbeddingModel(EmbeddingModel):
         return sys.platform.startswith("linux")

     def load(self):
+        # add truncate_dim args hint
+        if "dimensions" in self._kwargs:
+            raise NotImplementedError(
+                "LlamaCpp embedder does not support dimensions argument now."
+            )
         try:
             from xllamacpp import (
                 CommonParams,
                 Server,
+                __version__,
                 estimate_gpu_layers,
                 get_device_info,
                 ggml_backend_dev_type,
                 llama_pooling_type,
             )
+
+            try:
+                if version.parse(__version__) < version.parse("0.2.0"):
+                    raise RuntimeError(
+                        "Please update xllamacpp to >= 0.2.0 by `pip install -U xllamacpp`"
+                    )
+            except version.InvalidVersion:
+                pass  # If the version parse failed, we just skip the version check.
         except ImportError:
             error_message = "Failed to import module 'xllamacpp'"
             installation_guide = ["Please make sure 'xllamacpp' is installed. "]
@@ -162,7 +176,8 @@ class XllamaCppEmbeddingModel(EmbeddingModel):
                 )
                 logger.info("Estimate num gpu layers: %s", estimate)
                 if estimate.tensor_split:
-                    params.tensor_split = estimate.tensor_split
+                    for i in range(len(estimate.tensor_split)):
+                        params.tensor_split[i] = estimate.tensor_split[i]
                 else:
                     params.n_gpu_layers = estimate.layers
             except Exception as e:
@@ -190,24 +205,12 @@ class XllamaCppEmbeddingModel(EmbeddingModel):
         model_uid: Optional[str] = kwargs.pop("model_uid", None)
         if model_uid:
             data["model"] = model_uid
-        prompt_json = orjson.dumps(data)
-
-        def _error_callback(err):
-            try:
-                msg = orjson.loads(err)
-                q.put(_Error(msg))
-            except Exception as e:
-                q.put(_Error(str(e)))
-
-        def _ok_callback(ok):
-            try:
-                res = orjson.loads(ok)
-                q.put(res)
-            except Exception as e:
-                q.put(_Error(str(e)))
-
         try:
-            self._llm.handle_embeddings(prompt_json, _error_callback, _ok_callback)
+            res = self._llm.handle_embeddings(data)
+            if res.get("code"):
+                q.put(_Error(res))
+            else:
+                q.put(res)
         except Exception as ex:
             q.put(_Error(str(ex)))
         q.put(_Done)
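
Two of these hunks deserve a gloss. The tensor_split change copies element by element, presumably because params.tensor_split is a fixed-size native buffer exposed by the binding, so rebinding the attribute to a Python list would not reach the underlying C array; and handle_embeddings has evidently become synchronous in xllamacpp >= 0.2.0, returning a dict (with a "code" key on error) instead of invoking JSON callbacks. A sketch of the buffer-copy pattern with a hypothetical params object:

    class Params:  # hypothetical stand-in for xllamacpp.CommonParams
        def __init__(self, n: int):
            self.tensor_split = [0.0] * n  # imagine a preallocated C float array

    def apply_split(params: Params, estimate_split: list) -> None:
        # params.tensor_split = estimate_split would rebind the Python attribute
        # without touching the native buffer; copy element-wise instead.
        for i in range(len(estimate_split)):
            params.tensor_split[i] = estimate_split[i]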
xinference/model/embedding/sentence_transformers/core.py
@@ -71,6 +71,12 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
             )
             torch_dtype = torch.float32

+        dimensions = self._kwargs.get("dimensions")
+        assert dimensions is None or isinstance(dimensions, int), (
+            "The `dimensions` argument must be an integer, "
+            f"but got {type(dimensions)}: {dimensions}"
+        )
+
         if (
             "gte" in self.model_family.model_name.lower()
             and "qwen2" in self.model_family.model_name.lower()
@@ -82,6 +88,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 self._model_path,
                 device=self._device,
                 model_kwargs=model_kwargs,
+                truncate_dim=dimensions,
             )
         elif "qwen3" in self.model_family.model_name.lower():
             # qwen3 embedding
@@ -106,6 +113,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 device=self._device,
                 model_kwargs=model_kwargs,
                 tokenizer_kwargs=tokenizer_kwargs,
+                truncate_dim=dimensions,
             )
         else:
             model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
@@ -114,6 +122,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 device=self._device,
                 model_kwargs=model_kwargs,
                 trust_remote_code=True,
+                truncate_dim=dimensions,
             )

         if hasattr(self._model, "tokenizer"):
@@ -256,10 +265,10 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 "clip" in self.model_family.model_name.lower()
                 or "jina-embeddings-v4" in self.model_family.model_name.lower()
             ):
-                if "input_ids" in features and hasattr(
-                    features["input_ids"], "numel"
-                ):
-                    all_token_nums += features["input_ids"].numel()
+                # support input_ids and text_input_ids
+                for key in ["input_ids", "text_input_ids"]:
+                    if key in features and hasattr(features[key], "numel"):
+                        all_token_nums += features[key].numel()
                 if "pixel_values" in features and hasattr(
                     features["pixel_values"], "numel"
                 ):
@@ -270,6 +279,12 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
             with torch.no_grad():
                 out_features = model.forward(features, **kwargs)

+            from sentence_transformers.util import truncate_embeddings
+
+            out_features["sentence_embedding"] = truncate_embeddings(
+                out_features["sentence_embedding"], model.truncate_dim
+            )
+
             if output_value == "token_embeddings":
                 embeddings = []
                 for token_emb, attention in zip(
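
For reference, truncate_dim is sentence-transformers' built-in Matryoshka truncation (available since sentence-transformers 2.7); a minimal standalone sketch, with an illustrative model name:

    from sentence_transformers import SentenceTransformer

    # Keep only the first 256 dimensions of each embedding; Matryoshka-trained
    # models are designed so the leading dimensions remain useful on their own.
    model = SentenceTransformer("BAAI/bge-m3", truncate_dim=256)
    emb = model.encode(["hello world"])
    print(emb.shape)  # (1, 256) instead of the model's native 1024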
xinference/model/embedding/vllm/core.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import importlib.util
+import json
 import logging
 from typing import List, Union

@@ -25,7 +26,6 @@ SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "Qwen3"]


 class VLLMEmbeddingModel(EmbeddingModel):
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._context_length = None
@@ -42,13 +42,31 @@ class VLLMEmbeddingModel(EmbeddingModel):
             ]

             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+        if self.model_family.model_name in {
+            "Qwen3-Embedding-0.6B",
+            "Qwen3-Embedding-4B",
+            "Qwen3-Embedding-8B",
+        }:
+            if "hf_overrides" not in self._kwargs:
+                self._kwargs["hf_overrides"] = {
+                    "is_matryoshka": True,
+                }
+            elif isinstance(self._kwargs["hf_overrides"], dict):
+                self._kwargs["hf_overrides"].update(
+                    is_matryoshka=True,
+                )
+            elif isinstance(self._kwargs["hf_overrides"], str):
+                self._kwargs["hf_overrides"] = json.loads(self._kwargs["hf_overrides"])
+                self._kwargs["hf_overrides"].update(
+                    is_matryoshka=True,
+                )

         self._model = LLM(model=self._model_path, task="embed", **self._kwargs)
         self._tokenizer = self._model.get_tokenizer()

     @staticmethod
     def _get_detailed_instruct(task_description: str, query: str) -> str:
-        return f"Instruct: {task_description}\nQuery:{query}"
+        return f"Instruct: {task_description}\nQuery:{query}"  # noqa: E231

     @cache_clean
     def create_embedding(
@@ -56,14 +74,15 @@ class VLLMEmbeddingModel(EmbeddingModel):
         sentences: Union[str, List[str]],
         **kwargs,
     ):
+        from packaging.version import Version
+        from vllm import PoolingParams
+        from vllm import __version__ as vllm_version
+
         sentences = self._fix_langchain_openai_inputs(sentences)
         model_uid = kwargs.pop("model_uid", None)

         normalize_embedding = kwargs.get("normalize_embedding", True)
-        if not normalize_embedding:
-            raise ValueError(
-                "vllm embedding engine does not support setting `normalize_embedding=False`"
-            )
+        dimensions = kwargs.get("dimensions", None)

         assert self._model is not None

@@ -92,8 +111,21 @@ class VLLMEmbeddingModel(EmbeddingModel):
             sentences = truncated_sentences[0]
         else:
             sentences = truncated_sentences
-
-        outputs = self._model.embed(sentences, use_tqdm=False)
+        if Version(vllm_version) > Version("0.10.1"):
+            pool_params = PoolingParams(
+                dimensions=dimensions, normalize=normalize_embedding
+            )
+        else:
+            if not normalize_embedding:
+                raise ValueError(
+                    f"vLLM version {vllm_version} does not support "
+                    f"unnormalized embeddings. "
+                    f"Please upgrade to v0.10.1 or later."
+                )
+            pool_params = PoolingParams(dimensions=dimensions)
+        outputs = self._model.embed(
+            sentences, use_tqdm=False, pooling_params=pool_params
+        )
         embedding_list = []
         all_token_nums = 0
         for index, output in enumerate(outputs):
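
A hedged end-to-end sketch of the new vLLM path (assumes a vLLM newer than 0.10.1, where PoolingParams accepts both dimensions and normalize per the hunk above, and a Matryoshka-capable checkpoint):

    from vllm import LLM, PoolingParams

    llm = LLM(
        model="Qwen/Qwen3-Embedding-0.6B",
        task="embed",
        hf_overrides={"is_matryoshka": True},  # same override load() injects above
    )
    params = PoolingParams(dimensions=256, normalize=True)
    outputs = llm.embed(["What is the capital of France?"], pooling_params=params)
    print(len(outputs[0].outputs.embedding))  # 256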
xinference/model/image/cache_manager.py
@@ -60,3 +60,59 @@ class ImageCacheManager(CacheManager):
             raise NotImplementedError

         return full_path
+
+    def cache_lightning(self, lightning_version: Optional[str] = None):
+        from ..utils import IS_NEW_HUGGINGFACE_HUB, retry_download, symlink_local_file
+        from .core import ImageModelFamilyV2
+
+        if not lightning_version:
+            return None
+
+        assert isinstance(self._model_family, ImageModelFamilyV2)
+        cache_dir = self.get_cache_dir()
+
+        if not self._model_family.lightning_model_file_name_template:
+            raise NotImplementedError(
+                f"{self._model_family.model_name} does not support lightning"
+            )
+        if lightning_version not in (self._model_family.lightning_versions or []):
+            raise ValueError(
+                f"Cannot support lightning version {lightning_version}, "
+                f"available lightning version: {self._model_family.lightning_versions}"
+            )
+
+        filename = self._model_family.lightning_model_file_name_template.format(lightning_version=lightning_version)  # type: ignore
+        full_path = os.path.join(cache_dir, filename)
+
+        if self._model_family.model_hub == "huggingface":
+            import huggingface_hub
+
+            use_symlinks = {}
+            if not IS_NEW_HUGGINGFACE_HUB:
+                use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
+            download_file_path = retry_download(
+                huggingface_hub.hf_hub_download,
+                self._model_family.model_name,
+                None,
+                self._model_family.lightning_model_id,
+                filename=filename,
+                **use_symlinks,
+            )
+            if IS_NEW_HUGGINGFACE_HUB:
+                symlink_local_file(download_file_path, cache_dir, filename)
+        elif self._model_family.model_hub == "modelscope":
+            from modelscope.hub.file_download import model_file_download
+
+            download_file_path = retry_download(
+                model_file_download,
+                self._model_family.model_name,
+                None,
+                self._model_family.lightning_model_id,
+                filename,
+                revision=self._model_family.model_revision,
+            )
+            symlink_local_file(download_file_path, cache_dir, filename)
+        else:
+            raise NotImplementedError
+
+        return full_path
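
A usage sketch for the new method (the constructor argument and spec fields such as lightning_versions are inferred from the code above; the version tag is hypothetical):

    from xinference.model.image.cache_manager import ImageCacheManager

    # model_family is an ImageModelFamilyV2 whose spec defines lightning_versions
    # and lightning_model_file_name_template; "4steps" is an illustrative tag.
    manager = ImageCacheManager(model_family)
    path = manager.cache_lightning("4steps")
    if path:
        print("lightning weights cached at", path)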