xinference 1.9.0__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of xinference might be problematic.
Files changed (74)
  1. xinference/_version.py +3 -3
  2. xinference/core/model.py +3 -4
  3. xinference/core/worker.py +4 -1
  4. xinference/deploy/cmdline.py +2 -0
  5. xinference/deploy/test/test_cmdline.py +1 -1
  6. xinference/model/audio/cosyvoice.py +0 -1
  7. xinference/model/audio/model_spec.json +44 -20
  8. xinference/model/embedding/flag/core.py +5 -0
  9. xinference/model/embedding/llama_cpp/core.py +22 -19
  10. xinference/model/embedding/sentence_transformers/core.py +15 -0
  11. xinference/model/embedding/vllm/core.py +33 -7
  12. xinference/model/image/cache_manager.py +56 -0
  13. xinference/model/image/core.py +9 -0
  14. xinference/model/image/model_spec.json +114 -6
  15. xinference/model/image/stable_diffusion/core.py +141 -31
  16. xinference/model/llm/llama_cpp/core.py +41 -40
  17. xinference/model/llm/llm_family.json +395 -3
  18. xinference/model/llm/transformers/core.py +5 -11
  19. xinference/model/llm/utils.py +1 -1
  20. xinference/model/llm/vllm/core.py +6 -0
  21. xinference/model/rerank/core.py +3 -0
  22. xinference/model/rerank/sentence_transformers/core.py +1 -1
  23. xinference/model/rerank/vllm/core.py +56 -6
  24. xinference/model/utils.py +1 -2
  25. xinference/model/video/model_spec.json +95 -1
  26. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  27. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  28. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  29. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  30. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  31. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  32. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  33. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  34. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  35. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  36. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  37. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  38. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  39. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  40. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  41. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  42. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  43. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  44. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  45. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  46. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  47. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  48. xinference/ui/gradio/chat_interface.py +2 -0
  49. xinference/ui/gradio/media_interface.py +353 -7
  50. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  51. xinference/ui/web/ui/build/index.html +1 -1
  52. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  53. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  54. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  55. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  56. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  57. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  58. xinference/ui/web/ui/src/locales/en.json +2 -0
  59. xinference/ui/web/ui/src/locales/ja.json +2 -0
  60. xinference/ui/web/ui/src/locales/ko.json +2 -0
  61. xinference/ui/web/ui/src/locales/zh.json +2 -0
  62. {xinference-1.9.0.dist-info → xinference-1.9.1.dist-info}/METADATA +10 -10
  63. {xinference-1.9.0.dist-info → xinference-1.9.1.dist-info}/RECORD +68 -67
  64. xinference/ui/web/ui/build/static/js/main.4918643a.js +0 -3
  65. xinference/ui/web/ui/build/static/js/main.4918643a.js.map +0 -1
  66. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  67. xinference/ui/web/ui/node_modules/.cache/babel-loader/89179f8f51887b9167721860a12412549ff04f78162e921a7b6aa6532646deb2.json +0 -1
  68. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  69. xinference/ui/web/ui/node_modules/.cache/babel-loader/9dc5cfc67dd0617b0272aeef8651f1589b2155a4ff1fd72ad3166b217089b619.json +0 -1
  70. /xinference/ui/web/ui/build/static/js/{main.4918643a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  71. {xinference-1.9.0.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
  72. {xinference-1.9.0.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
  73. {xinference-1.9.0.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
  74. {xinference-1.9.0.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2025-08-16T21:34:08+0800",
+ "date": "2025-08-30T03:57:39+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "38e0401f83799f57d42ef948c57782466b8e4777",
- "version": "1.9.0"
+ "full-revisionid": "b2d793d0b4a0af632932eb63dbeb1bc91b5b3d74",
+ "version": "1.9.1"
 }
 ''' # END VERSION_JSON
 
xinference/core/model.py CHANGED
@@ -882,10 +882,9 @@ class ModelActor(xo.StatelessActor, CancelMixin):
         **kwargs,
     ):
         if hasattr(self._model, "text_to_image"):
-            # Directly delegate to model, let model decide how to handle (batching or not)
-            progressor = kwargs["progressor"] = await self._get_progressor(
-                kwargs.pop("request_id", None)
-            )
+            # Get progressor (don't pop request_id, let _call_wrapper handle cancellation)
+            request_id = kwargs.get("request_id")
+            progressor = kwargs["progressor"] = await self._get_progressor(request_id)  # type: ignore
             with progressor:
                 return await self._call_wrapper_json(
                     self._model.text_to_image,
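
Note on the model.py change above: the old code popped request_id before handing kwargs to _call_wrapper_json, so the wrapper could no longer associate the call with a cancellable request. A minimal sketch of the pop-vs-get difference (the dict contents are illustrative only, not xinference's real kwargs):

    # pop() removes the key, so later readers of kwargs miss it;
    # get() reads the value while leaving kwargs intact.
    kwargs = {"request_id": "req-42", "prompt": "a cat"}
    assert kwargs.pop("request_id", None) == "req-42"
    assert "request_id" not in kwargs   # downstream cancellation lookup fails

    kwargs = {"request_id": "req-42", "prompt": "a cat"}
    assert kwargs.get("request_id") == "req-42"
    assert "request_id" in kwargs       # downstream still sees the id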
xinference/core/worker.py CHANGED
@@ -827,10 +827,13 @@ class WorkerActor(xo.StatelessActor):
         settings: Optional[VirtualEnvSettings],
         virtual_env_packages: Optional[List[str]],
     ):
-        if not settings or not settings.packages:
+        if (not settings or not settings.packages) and not virtual_env_packages:
             # no settings or no packages
             return
 
+        if settings is None:
+            settings = VirtualEnvSettings(packages=virtual_env_packages)
+
         if settings.inherit_pip_config:
             # inherit pip config
             pip_config = get_pip_config_args()
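
Note on the worker.py change above: previously, packages supplied via virtual_env_packages were silently dropped whenever no VirtualEnvSettings was configured. A runnable sketch of the new decision logic, with a stand-in dataclass in place of xinference's real VirtualEnvSettings model:

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class VirtualEnvSettings:  # stand-in for the real settings class
        packages: List[str] = field(default_factory=list)

    def resolve(settings: Optional[VirtualEnvSettings],
                virtual_env_packages: Optional[List[str]]) -> Optional[VirtualEnvSettings]:
        if (not settings or not settings.packages) and not virtual_env_packages:
            return None  # nothing to install
        if settings is None:
            # extra packages now synthesize settings instead of being ignored
            settings = VirtualEnvSettings(packages=virtual_env_packages or [])
        return settings

    assert resolve(None, None) is None
    assert resolve(None, ["torch==2.3.0"]).packages == ["torch==2.3.0"]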
xinference/deploy/cmdline.py CHANGED
@@ -1345,6 +1345,8 @@ def model_chat(
         messages,
         generate_config={"stream": stream, "max_tokens": max_tokens},
     ):
+        if not chunk["choices"]:
+            continue
         delta = chunk["choices"][0]["delta"]
         if "content" not in delta:
             continue
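
Note on the cmdline.py change above: OpenAI-compatible streaming backends can emit chunks whose choices list is empty (for example, a trailing usage-only chunk), and indexing choices[0] on such a chunk raises IndexError. A self-contained illustration with made-up chunks:

    chunks = [
        {"choices": [{"delta": {"content": "Hel"}}]},
        {"choices": [{"delta": {"content": "lo"}}]},
        {"choices": [], "usage": {"total_tokens": 7}},  # would crash without the guard
    ]

    out = []
    for chunk in chunks:
        if not chunk["choices"]:
            continue
        delta = chunk["choices"][0]["delta"]
        if "content" not in delta:
            continue
        out.append(delta["content"])

    assert "".join(out) == "Hello"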
xinference/deploy/test/test_cmdline.py CHANGED
@@ -87,7 +87,7 @@ def test_cmdline(setup, stream, model_uid):
         ],
     )
     assert result.exit_code == 0
-    assert model_uid in result.stdout
+    assert model_uid in result.output
 
     # model generate
     result = runner.invoke(
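
Note on the test change above: Click's CliRunner result always exposes the captured text as result.output, while the semantics of result.stdout have shifted across Click releases (stderr handling in particular), so output is the portable accessor. A minimal self-contained example of the pattern the test relies on:

    import click
    from click.testing import CliRunner

    @click.command()
    @click.argument("model_uid")
    def show(model_uid):
        click.echo(model_uid)

    result = CliRunner().invoke(show, ["my-model"])
    assert result.exit_code == 0
    assert "my-model" in result.output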
xinference/model/audio/cosyvoice.py CHANGED
@@ -60,7 +60,6 @@ class CosyVoiceModel:
             from cosyvoice.cli.cosyvoice import CosyVoice2 as CosyVoice
 
             self._is_cosyvoice2 = True
-            kwargs = {"use_flow_cache": self._kwargs.get("use_flow_cache", False)}
         else:
             from cosyvoice.cli.cosyvoice import CosyVoice
 
xinference/model/audio/model_spec.json CHANGED
@@ -525,7 +525,8 @@
     "model_name": "ChatTTS",
     "model_family": "ChatTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "virtualenv": {
@@ -551,7 +552,8 @@
     "model_name": "CosyVoice-300M",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -570,7 +572,8 @@
     "model_name": "CosyVoice-300M-SFT",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -589,7 +592,8 @@
     "model_name": "CosyVoice-300M-Instruct",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
    ],
     "multilingual": true,
     "model_src": {
@@ -608,7 +612,9 @@
     "model_name": "CosyVoice2-0.5B",
     "model_family": "CosyVoice",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "virtualenv": {
@@ -625,7 +631,8 @@
       "HyperPyYAML",
       "onnxruntime>=1.16.0",
       "pyworld>=0.3.4",
-      "WeTextProcessing<1.0.4",
+      "wetext==0.0.9",
+      "transformers==4.51.3",
       "#system_numpy#",
       "#system_torch#"
     ]
@@ -646,7 +653,9 @@
     "model_name": "FishSpeech-1.5",
     "model_family": "FishAudio",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -665,7 +674,9 @@
     "model_name": "F5-TTS",
     "model_family": "F5-TTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -684,7 +695,9 @@
     "model_name": "F5-TTS-MLX",
     "model_family": "F5-TTS-MLX",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot",
+      "text2audio_voice_cloning"
     ],
     "multilingual": true,
     "model_src": {
@@ -699,7 +712,8 @@
     "model_name": "MeloTTS-English",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "EN",
@@ -715,7 +729,8 @@
     "model_name": "MeloTTS-English-v2",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "EN",
@@ -731,7 +746,8 @@
     "model_name": "MeloTTS-English-v3",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "EN",
@@ -747,7 +763,8 @@
     "model_name": "MeloTTS-French",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "FR",
@@ -763,7 +780,8 @@
     "model_name": "MeloTTS-Japanese",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "JP",
@@ -779,7 +797,8 @@
     "model_name": "MeloTTS-Spanish",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "ES",
@@ -795,7 +814,8 @@
     "model_name": "MeloTTS-Chinese",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "ZH",
@@ -811,7 +831,8 @@
     "model_name": "MeloTTS-Korean",
     "model_family": "MeloTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": false,
     "language": "KR",
@@ -827,7 +848,8 @@
     "model_name": "Kokoro-82M",
     "model_family": "Kokoro",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -846,7 +868,8 @@
     "model_name": "Kokoro-82M-MLX",
     "model_family": "Kokoro-MLX",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
@@ -874,7 +897,8 @@
     "model_name": "MegaTTS3",
     "model_family": "MegaTTS",
     "model_ability": [
-      "text2audio"
+      "text2audio",
+      "text2audio_zero_shot"
     ],
     "multilingual": true,
     "model_src": {
xinference/model/embedding/flag/core.py CHANGED
@@ -58,6 +58,11 @@ class FlagEmbeddingModel(EmbeddingModel):
         self._return_sparse = return_sparse
 
     def load(self):
+        # add truncate_dim args hint
+        if self._kwargs and "dimensions" in self._kwargs:
+            raise NotImplementedError(
+                "Flag embedder does not support dimensions argument now."
+            )
         try:
             from FlagEmbedding import BGEM3FlagModel
         except ImportError:
xinference/model/embedding/llama_cpp/core.py CHANGED
@@ -22,7 +22,7 @@ import queue
 import sys
 from typing import List, Optional, Union
 
-import orjson
+from packaging import version
 
 from ....types import Embedding
 from ..core import EmbeddingModel, EmbeddingModelFamilyV2, EmbeddingSpecV1
@@ -69,15 +69,29 @@ class XllamaCppEmbeddingModel(EmbeddingModel):
         return sys.platform.startswith("linux")
 
     def load(self):
+        # add truncate_dim args hint
+        if "dimensions" in self._kwargs:
+            raise NotImplementedError(
+                "LlamaCpp embedder does not support dimensions argument now."
+            )
         try:
             from xllamacpp import (
                 CommonParams,
                 Server,
+                __version__,
                 estimate_gpu_layers,
                 get_device_info,
                 ggml_backend_dev_type,
                 llama_pooling_type,
             )
+
+            try:
+                if version.parse(__version__) < version.parse("0.2.0"):
+                    raise RuntimeError(
+                        "Please update xllamacpp to >= 0.2.0 by `pip install -U xllamacpp`"
+                    )
+            except version.InvalidVersion:
+                pass  # If the version parse failed, we just skip the version check.
         except ImportError:
             error_message = "Failed to import module 'xllamacpp'"
             installation_guide = ["Please make sure 'xllamacpp' is installed. "]
@@ -162,7 +176,8 @@ class XllamaCppEmbeddingModel(EmbeddingModel):
             )
             logger.info("Estimate num gpu layers: %s", estimate)
             if estimate.tensor_split:
-                params.tensor_split = estimate.tensor_split
+                for i in range(len(estimate.tensor_split)):
+                    params.tensor_split[i] = estimate.tensor_split[i]
             else:
                 params.n_gpu_layers = estimate.layers
         except Exception as e:
@@ -190,24 +205,12 @@ class XllamaCppEmbeddingModel(EmbeddingModel):
         model_uid: Optional[str] = kwargs.pop("model_uid", None)
         if model_uid:
             data["model"] = model_uid
-        prompt_json = orjson.dumps(data)
-
-        def _error_callback(err):
-            try:
-                msg = orjson.loads(err)
-                q.put(_Error(msg))
-            except Exception as e:
-                q.put(_Error(str(e)))
-
-        def _ok_callback(ok):
-            try:
-                res = orjson.loads(ok)
-                q.put(res)
-            except Exception as e:
-                q.put(_Error(str(e)))
-
         try:
-            self._llm.handle_embeddings(prompt_json, _error_callback, _ok_callback)
+            res = self._llm.handle_embeddings(data)
+            if res.get("code"):
+                q.put(_Error(res))
+            else:
+                q.put(res)
         except Exception as ex:
             q.put(_Error(str(ex)))
         q.put(_Done)
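
Note on the llama_cpp changes above: besides switching handle_embeddings to a dict-based API, the loader now gates on the installed xllamacpp version with packaging's PEP 440 parser. The same pattern, runnable standalone (the helper name is mine; only the packaging library is required):

    from packaging import version

    def check_min_version(ver: str, minimum: str = "0.2.0") -> bool:
        """True if ver meets the minimum; unparseable strings skip the check."""
        try:
            return version.parse(ver) >= version.parse(minimum)
        except version.InvalidVersion:
            return True  # mirror the diff: skip the check on parse failure

    assert check_min_version("0.2.1")
    assert not check_min_version("0.1.9")
    assert check_min_version("not-a-version")  # unparseable -> check skipped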
xinference/model/embedding/sentence_transformers/core.py CHANGED
@@ -71,6 +71,12 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
             )
             torch_dtype = torch.float32
 
+        dimensions = self._kwargs.get("dimensions")
+        assert dimensions is None or isinstance(dimensions, int), (
+            "The `dimensions` argument must be an integer, "
+            f"but got {type(dimensions)}: {dimensions}"
+        )
+
         if (
             "gte" in self.model_family.model_name.lower()
             and "qwen2" in self.model_family.model_name.lower()
@@ -82,6 +88,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 self._model_path,
                 device=self._device,
                 model_kwargs=model_kwargs,
+                truncate_dim=dimensions,
             )
         elif "qwen3" in self.model_family.model_name.lower():
             # qwen3 embedding
@@ -106,6 +113,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 device=self._device,
                 model_kwargs=model_kwargs,
                 tokenizer_kwargs=tokenizer_kwargs,
+                truncate_dim=dimensions,
             )
         else:
             model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
@@ -114,6 +122,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
                 device=self._device,
                 model_kwargs=model_kwargs,
                 trust_remote_code=True,
+                truncate_dim=dimensions,
             )
 
         if hasattr(self._model, "tokenizer"):
@@ -270,6 +279,12 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
         with torch.no_grad():
             out_features = model.forward(features, **kwargs)
 
+        from sentence_transformers.util import truncate_embeddings
+
+        out_features["sentence_embedding"] = truncate_embeddings(
+            out_features["sentence_embedding"], model.truncate_dim
+        )
+
         if output_value == "token_embeddings":
             embeddings = []
             for token_emb, attention in zip(
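
Note on the sentence_transformers changes above: the new dimensions kwarg is forwarded as SentenceTransformer's truncate_dim, which implements Matryoshka-style truncation; the manual forward path applies sentence_transformers.util.truncate_embeddings explicitly. A hedged sketch (requires sentence-transformers >= 2.7; the model name is just a small public example, not what xinference loads):

    from sentence_transformers import SentenceTransformer

    # all-MiniLM-L6-v2 natively emits 384-dim vectors; truncate_dim keeps 128.
    model = SentenceTransformer("all-MiniLM-L6-v2", truncate_dim=128)
    emb = model.encode(["hello world"])
    assert emb.shape[-1] == 128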
xinference/model/embedding/vllm/core.py CHANGED
@@ -25,7 +25,6 @@ SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "Qwen3"]
 
 
 class VLLMEmbeddingModel(EmbeddingModel):
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._context_length = None
@@ -42,6 +41,19 @@ class VLLMEmbeddingModel(EmbeddingModel):
             ]
 
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+        if self.model_family.model_name in {
+            "Qwen3-Embedding-0.6B",
+            "Qwen3-Embedding-4B",
+            "Qwen3-Embedding-8B",
+        }:
+            if "hf_overrides" not in self._kwargs:
+                self._kwargs["hf_overrides"] = {
+                    "is_matryoshka": True,
+                }
+            elif isinstance(self._kwargs["hf_overrides"], dict):
+                self._kwargs["hf_overrides"].update(
+                    is_matryoshka=True,
+                )
 
         self._model = LLM(model=self._model_path, task="embed", **self._kwargs)
         self._tokenizer = self._model.get_tokenizer()
@@ -56,14 +68,15 @@ class VLLMEmbeddingModel(EmbeddingModel):
         sentences: Union[str, List[str]],
         **kwargs,
     ):
+        from packaging.version import Version
+        from vllm import PoolingParams
+        from vllm import __version__ as vllm_version
+
         sentences = self._fix_langchain_openai_inputs(sentences)
         model_uid = kwargs.pop("model_uid", None)
 
         normalize_embedding = kwargs.get("normalize_embedding", True)
-        if not normalize_embedding:
-            raise ValueError(
-                "vllm embedding engine does not support setting `normalize_embedding=False`"
-            )
+        dimensions = kwargs.get("dimensions", None)
 
         assert self._model is not None
 
@@ -92,8 +105,21 @@ class VLLMEmbeddingModel(EmbeddingModel):
             sentences = truncated_sentences[0]
         else:
             sentences = truncated_sentences
-
-        outputs = self._model.embed(sentences, use_tqdm=False)
+        if Version(vllm_version) > Version("0.10.1"):
+            pool_params = PoolingParams(
+                dimensions=dimensions, normalize=normalize_embedding
+            )
+        else:
+            if not normalize_embedding:
+                raise ValueError(
+                    f"vLLM version {vllm_version} does not support "
+                    f"unnormalized embeddings. "
+                    f"Please upgrade to v0.10.1 or later."
+                )
+            pool_params = PoolingParams(dimensions=dimensions)
+        outputs = self._model.embed(
+            sentences, use_tqdm=False, pooling_params=pool_params
+        )
         embedding_list = []
         all_token_nums = 0
         for index, output in enumerate(outputs):
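
Note on the vLLM embedding changes above: per the diff, vLLM releases newer than 0.10.1 accept dimensions and normalize on PoolingParams, so Matryoshka truncation and unnormalized output move into the engine; older versions keep the hard error for normalize_embedding=False. A hedged sketch of the new call path (assumes vllm > 0.10.1 and a Matryoshka-capable model; names mirror the code above, the model id is illustrative):

    from vllm import LLM, PoolingParams

    llm = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed",
              hf_overrides={"is_matryoshka": True})
    params = PoolingParams(dimensions=256, normalize=True)
    outputs = llm.embed(["hello world"], use_tqdm=False, pooling_params=params)
    assert len(outputs[0].outputs.embedding) == 256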
xinference/model/image/cache_manager.py CHANGED
@@ -60,3 +60,59 @@ class ImageCacheManager(CacheManager):
             raise NotImplementedError
 
         return full_path
+
+    def cache_lightning(self, lightning_version: Optional[str] = None):
+        from ..utils import IS_NEW_HUGGINGFACE_HUB, retry_download, symlink_local_file
+        from .core import ImageModelFamilyV2
+
+        if not lightning_version:
+            return None
+
+        assert isinstance(self._model_family, ImageModelFamilyV2)
+        cache_dir = self.get_cache_dir()
+
+        if not self._model_family.lightning_model_file_name_template:
+            raise NotImplementedError(
+                f"{self._model_family.model_name} does not support lightning"
+            )
+        if lightning_version not in (self._model_family.lightning_versions or []):
+            raise ValueError(
+                f"Cannot support lightning version {lightning_version}, "
+                f"available lightning version: {self._model_family.lightning_versions}"
+            )
+
+        filename = self._model_family.lightning_model_file_name_template.format(lightning_version=lightning_version)  # type: ignore
+        full_path = os.path.join(cache_dir, filename)
+
+        if self._model_family.model_hub == "huggingface":
+            import huggingface_hub
+
+            use_symlinks = {}
+            if not IS_NEW_HUGGINGFACE_HUB:
+                use_symlinks = {"local_dir_use_symlinks": True, "local_dir": cache_dir}
+            download_file_path = retry_download(
+                huggingface_hub.hf_hub_download,
+                self._model_family.model_name,
+                None,
+                self._model_family.lightning_model_id,
+                filename=filename,
+                **use_symlinks,
+            )
+            if IS_NEW_HUGGINGFACE_HUB:
+                symlink_local_file(download_file_path, cache_dir, filename)
+        elif self._model_family.model_hub == "modelscope":
+            from modelscope.hub.file_download import model_file_download
+
+            download_file_path = retry_download(
+                model_file_download,
+                self._model_family.model_name,
+                None,
+                self._model_family.lightning_model_id,
+                filename,
+                revision=self._model_family.model_revision,
+            )
+            symlink_local_file(download_file_path, cache_dir, filename)
+        else:
+            raise NotImplementedError
+
+        return full_path
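
Note on the cache_manager.py addition above: cache_lightning resolves a per-version filename from the family's template, validates the requested version, then downloads from Hugging Face or ModelScope. The validation and template logic, runnable standalone with made-up values (the real template and versions come from model_spec.json):

    import os

    lightning_versions = ["4steps", "8steps"]
    template = "lightning-{lightning_version}.safetensors"
    cache_dir = "/tmp/xinference/cache/demo-model"

    lightning_version = "8steps"
    if lightning_version not in (lightning_versions or []):
        raise ValueError(
            f"Cannot support lightning version {lightning_version}, "
            f"available lightning version: {lightning_versions}"
        )
    filename = template.format(lightning_version=lightning_version)
    full_path = os.path.join(cache_dir, filename)
    assert full_path.endswith("lightning-8steps.safetensors")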
xinference/model/image/core.py CHANGED
@@ -51,6 +51,10 @@ class ImageModelFamilyV2(CacheableModelSpec, ModelInstanceInfoMixin):
     gguf_model_id: Optional[str]
     gguf_quantizations: Optional[List[str]]
     gguf_model_file_name_template: Optional[str]
+    lightning_model_id: Optional[str]
+    lightning_versions: Optional[List[str]]
+    lightning_model_file_name_template: Optional[str]
+
     virtualenv: Optional[VirtualEnvSettings]
 
     class Config:
@@ -180,6 +184,8 @@ def create_image_model_instance(
     model_path: Optional[str] = None,
     gguf_quantization: Optional[str] = None,
     gguf_model_path: Optional[str] = None,
+    lightning_version: Optional[str] = None,
+    lightning_model_path: Optional[str] = None,
     **kwargs,
 ) -> Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model]:
     from .cache_manager import ImageCacheManager
@@ -235,6 +241,8 @@ def create_image_model_instance(
         model_path = cache_manager.cache()
     if not gguf_model_path and gguf_quantization:
         gguf_model_path = cache_manager.cache_gguf(gguf_quantization)
+    if not lightning_model_path and lightning_version:
+        lightning_model_path = cache_manager.cache_lightning(lightning_version)
     if peft_model_config is not None:
         lora_model = peft_model_config.peft_model
         lora_load_kwargs = peft_model_config.image_lora_load_kwargs
@@ -262,6 +270,7 @@ def create_image_model_instance(
         lora_fuse_kwargs=lora_fuse_kwargs,
         model_spec=model_spec,
         gguf_model_path=gguf_model_path,
+        lightning_model_path=lightning_model_path,
         **kwargs,
     )
     return model
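
Note on the image/core.py changes above: the two new parameters thread through create_image_model_instance the same way the gguf pair does, with an explicit local path taking precedence over on-demand caching. A minimal sketch of that dispatch with a stub cache manager (the real code calls ImageCacheManager.cache_lightning):

    from typing import Optional

    class StubCacheManager:
        def cache_lightning(self, version: str) -> str:
            return f"/cache/lightning-{version}.safetensors"

    def resolve_lightning_path(cm: StubCacheManager,
                               lightning_version: Optional[str],
                               lightning_model_path: Optional[str]) -> Optional[str]:
        # An explicit local path wins; otherwise a requested version is cached.
        if not lightning_model_path and lightning_version:
            lightning_model_path = cm.cache_lightning(lightning_version)
        return lightning_model_path

    assert resolve_lightning_path(StubCacheManager(), "8steps", None) \
        == "/cache/lightning-8steps.safetensors"
    assert resolve_lightning_path(StubCacheManager(), None, "/local.safetensors") \
        == "/local.safetensors"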