xinference 1.6.0.post1__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of xinference might be problematic.

Files changed (124)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +79 -2
  3. xinference/client/restful/restful_client.py +65 -3
  4. xinference/conftest.py +0 -7
  5. xinference/core/media_interface.py +132 -8
  6. xinference/core/model.py +44 -6
  7. xinference/core/scheduler.py +1 -10
  8. xinference/core/supervisor.py +8 -17
  9. xinference/core/worker.py +5 -27
  10. xinference/deploy/cmdline.py +6 -2
  11. xinference/model/audio/chattts.py +24 -39
  12. xinference/model/audio/cosyvoice.py +18 -30
  13. xinference/model/audio/funasr.py +42 -0
  14. xinference/model/audio/model_spec.json +71 -1
  15. xinference/model/audio/model_spec_modelscope.json +76 -2
  16. xinference/model/audio/utils.py +75 -0
  17. xinference/model/core.py +1 -0
  18. xinference/model/embedding/__init__.py +74 -18
  19. xinference/model/embedding/core.py +98 -589
  20. xinference/model/embedding/embed_family.py +133 -0
  21. xinference/{thirdparty/omnilmm/train → model/embedding/flag}/__init__.py +1 -1
  22. xinference/model/embedding/flag/core.py +282 -0
  23. xinference/model/embedding/model_spec.json +24 -0
  24. xinference/model/embedding/model_spec_modelscope.json +24 -0
  25. xinference/model/embedding/sentence_transformers/__init__.py +13 -0
  26. xinference/model/embedding/sentence_transformers/core.py +399 -0
  27. xinference/model/embedding/vllm/core.py +95 -0
  28. xinference/model/image/model_spec.json +30 -3
  29. xinference/model/image/model_spec_modelscope.json +41 -2
  30. xinference/model/image/stable_diffusion/core.py +144 -53
  31. xinference/model/llm/__init__.py +6 -54
  32. xinference/model/llm/core.py +19 -5
  33. xinference/model/llm/llama_cpp/core.py +59 -3
  34. xinference/model/llm/llama_cpp/memory.py +457 -0
  35. xinference/model/llm/llm_family.json +247 -402
  36. xinference/model/llm/llm_family.py +88 -16
  37. xinference/model/llm/llm_family_modelscope.json +260 -421
  38. xinference/model/llm/llm_family_openmind_hub.json +0 -34
  39. xinference/model/llm/sglang/core.py +8 -0
  40. xinference/model/llm/transformers/__init__.py +27 -6
  41. xinference/model/llm/transformers/chatglm.py +4 -2
  42. xinference/model/llm/transformers/core.py +49 -28
  43. xinference/model/llm/transformers/deepseek_v2.py +6 -49
  44. xinference/model/llm/transformers/gemma3.py +119 -164
  45. xinference/model/llm/transformers/multimodal/__init__.py +13 -0
  46. xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
  47. xinference/model/llm/transformers/multimodal/core.py +205 -0
  48. xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
  49. xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
  50. xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
  51. xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
  52. xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
  53. xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
  54. xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
  55. xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
  56. xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
  57. xinference/model/llm/transformers/opt.py +4 -2
  58. xinference/model/llm/transformers/utils.py +6 -37
  59. xinference/model/llm/utils.py +11 -0
  60. xinference/model/llm/vllm/core.py +7 -0
  61. xinference/model/rerank/core.py +91 -3
  62. xinference/model/rerank/model_spec.json +24 -0
  63. xinference/model/rerank/model_spec_modelscope.json +24 -0
  64. xinference/model/rerank/utils.py +20 -2
  65. xinference/model/utils.py +38 -1
  66. xinference/model/video/diffusers.py +65 -3
  67. xinference/model/video/model_spec.json +31 -4
  68. xinference/model/video/model_spec_modelscope.json +32 -4
  69. xinference/web/ui/build/asset-manifest.json +6 -6
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/css/main.013f296b.css +2 -0
  72. xinference/web/ui/build/static/css/main.013f296b.css.map +1 -0
  73. xinference/web/ui/build/static/js/main.8a9e3ba0.js +3 -0
  74. xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/6595880facebca7ceace6f17cf21c3a5a9219a2f52fb0ba9f3cf1131eddbcf6b.json +1 -0
  79. xinference/web/ui/node_modules/.cache/babel-loader/aa998bc2d9c11853add6b8a2e08f50327f56d8824ccaaec92d6dde1b305f0d85.json +1 -0
  80. xinference/web/ui/node_modules/.cache/babel-loader/c748246b1d7bcebc16153be69f37e955bb2145526c47dd425aeeff70d3004dbc.json +1 -0
  81. xinference/web/ui/node_modules/.cache/babel-loader/e31234e95d60a5a7883fbcd70de2475dc1c88c90705df1a530abb68f86f80a51.json +1 -0
  82. xinference/web/ui/src/locales/en.json +21 -8
  83. xinference/web/ui/src/locales/ja.json +224 -0
  84. xinference/web/ui/src/locales/ko.json +224 -0
  85. xinference/web/ui/src/locales/zh.json +21 -8
  86. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/METADATA +14 -11
  87. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/RECORD +93 -100
  88. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/WHEEL +1 -1
  89. xinference/model/llm/transformers/cogvlm2.py +0 -442
  90. xinference/model/llm/transformers/cogvlm2_video.py +0 -333
  91. xinference/model/llm/transformers/deepseek_vl.py +0 -280
  92. xinference/model/llm/transformers/glm_edge_v.py +0 -213
  93. xinference/model/llm/transformers/intern_vl.py +0 -526
  94. xinference/model/llm/transformers/internlm2.py +0 -94
  95. xinference/model/llm/transformers/minicpmv25.py +0 -193
  96. xinference/model/llm/transformers/omnilmm.py +0 -132
  97. xinference/model/llm/transformers/qwen2_audio.py +0 -179
  98. xinference/model/llm/transformers/qwen_vl.py +0 -360
  99. xinference/thirdparty/omnilmm/LICENSE +0 -201
  100. xinference/thirdparty/omnilmm/chat.py +0 -218
  101. xinference/thirdparty/omnilmm/constants.py +0 -4
  102. xinference/thirdparty/omnilmm/conversation.py +0 -332
  103. xinference/thirdparty/omnilmm/model/__init__.py +0 -1
  104. xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
  105. xinference/thirdparty/omnilmm/model/resampler.py +0 -166
  106. xinference/thirdparty/omnilmm/model/utils.py +0 -578
  107. xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
  108. xinference/thirdparty/omnilmm/utils.py +0 -134
  109. xinference/web/ui/build/static/css/main.337afe76.css +0 -2
  110. xinference/web/ui/build/static/css/main.337afe76.css.map +0 -1
  111. xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
  112. xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
  113. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +0 -1
  114. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
  115. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
  116. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
  117. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
  118. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +0 -1
  119. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +0 -1
  120. /xinference/{thirdparty/omnilmm → model/embedding/vllm}/__init__.py +0 -0
  121. /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.8a9e3ba0.js.LICENSE.txt} +0 -0
  122. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/entry_points.txt +0 -0
  123. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/licenses/LICENSE +0 -0
  124. {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/top_level.txt +0 -0
xinference/model/embedding/core.py
@@ -16,16 +16,13 @@ import gc
  import logging
  import os
  from collections import defaultdict
- from typing import Dict, List, Literal, Optional, Tuple, Union, no_type_check
-
- import numpy as np
- import torch
+ from typing import Dict, List, Literal, Optional, Tuple, Union

  from ..._compat import ROOT_KEY, ErrorWrapper, ValidationError
  from ...device_utils import empty_cache
- from ...types import Embedding, EmbeddingData, EmbeddingUsage
  from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
  from ..utils import get_cache_dir, is_model_cached
+ from .embed_family import match_embedding

  logger = logging.getLogger(__name__)

@@ -49,6 +46,7 @@ def get_embedding_model_descriptions():
  return copy.deepcopy(EMBEDDING_MODEL_DESCRIPTIONS)


+ # this class define the basic info of embedding model
  class EmbeddingModelSpec(CacheableModelSpec):
  model_name: str
  dimensions: int
@@ -128,7 +126,11 @@ def get_cache_status(
  return is_model_cached(model_spec, MODEL_NAME_TO_REVISION)


- class EmbeddingModel:
+ import abc
+ from abc import abstractmethod
+
+
+ class EmbeddingModel(abc.ABC):
  def __init__(
  self,
  model_uid: str,
@@ -141,95 +143,36 @@ class EmbeddingModel:
  self._model_path = model_path
  self._device = device
  self._model = None
+ self._tokenizer = None
  self._counter = 0
  self._model_spec = model_spec
+ self._model_name = self._model_spec.model_name
  self._kwargs = kwargs

+ @classmethod
+ @abstractmethod
+ def check_lib(cls) -> bool:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def match_json(cls, model_spec: EmbeddingModelSpec) -> bool:
+ pass
+
+ @classmethod
+ def match(cls, model_spec: EmbeddingModelSpec):
+ """
+ Return if the model_spec can be matched.
+ """
+ if not cls.check_lib():
+ return False
+ return cls.match_json(model_spec)
+
+ @abstractmethod
  def load(self):
- try:
- import sentence_transformers
- from sentence_transformers import SentenceTransformer
-
- if sentence_transformers.__version__ < "3.1.0":
- raise ValueError(
- "The sentence_transformers version must be greater than 3.1.0. "
- "Please upgrade your version via `pip install -U sentence_transformers` or refer to "
- "https://github.com/UKPLab/sentence-transformers"
- )
- except ImportError:
- error_message = "Failed to import module 'SentenceTransformer'"
- installation_guide = [
- "Please make sure 'sentence-transformers' is installed. ",
- "You can install it by `pip install sentence-transformers`\n",
- ]
- raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
- class XSentenceTransformer(SentenceTransformer):
- def to(self, *args, **kwargs):
- pass
-
- torch_dtype = None
- if torch_dtype_str := self._kwargs.get("torch_dtype"):
- try:
- torch_dtype = getattr(torch, torch_dtype_str)
- if torch_dtype not in [
- torch.float16,
- torch.float32,
- torch.bfloat16,
- ]:
- logger.warning(
- f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
- )
- torch_dtype = torch.float32
- except AttributeError:
- logger.warning(
- f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
- )
- torch_dtype = torch.float32
-
- if (
- "gte" in self._model_spec.model_name.lower()
- and "qwen2" in self._model_spec.model_name.lower()
- ):
- model_kwargs = {"device_map": "auto"}
- if torch_dtype:
- model_kwargs["torch_dtype"] = torch_dtype
- self._model = XSentenceTransformer(
- self._model_path,
- device=self._device,
- model_kwargs=model_kwargs,
- )
- elif (
- self._kwargs.get("hybrid_mode")
- and "m3" in self._model_spec.model_name.lower()
- ):
- try:
- from FlagEmbedding import BGEM3FlagModel
- except ImportError:
- error_message = "Failed to import module 'BGEM3FlagModel'"
- installation_guide = [
- "Please make sure 'FlagEmbedding' is installed. ",
- "You can install it by `pip install FlagEmbedding`\n",
- ]
- raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
- if torch_dtype and torch_dtype == torch.float16:
- model_kwargs = {"use_fp16": True}
- else:
- model_kwargs = {}
- self._model = BGEM3FlagModel(
- self._model_path,
- device=self._device,
- **model_kwargs,
- )
- else:
- model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
- self._model = SentenceTransformer(
- self._model_path,
- device=self._device,
- model_kwargs=model_kwargs,
- trust_remote_code=True,
- )
+ """
+ Load embedding model
+ """

  def _fix_langchain_openai_inputs(
  self, sentences: Union[str, List[str], Dict[str, str], List[Dict[str, str]]]
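The hunks above turn EmbeddingModel into an abstract base class: every engine now declares check_lib(), match_json(), load() and create_embedding(), and the shared match() helper combines the first two. Below is a minimal sketch of what an engine subclass could look like; the class name, import path and library calls are assumptions for illustration, not code from this release (the real engines live in the new flag/, sentence_transformers/ and vllm/ subpackages listed in the file table above).

```python
# Hedged sketch, not taken from the diff: a hypothetical engine subclass of the
# new abstract EmbeddingModel. Only the abstract interface (check_lib /
# match_json / load / create_embedding) comes from the hunks above.
import importlib.util

from xinference.model.embedding.core import EmbeddingModel, EmbeddingModelSpec  # assumed path


class DummyEngineEmbeddingModel(EmbeddingModel):
    """Hypothetical engine backed by sentence-transformers."""

    @classmethod
    def check_lib(cls) -> bool:
        # An engine is usable only if its backing library is importable.
        return importlib.util.find_spec("sentence_transformers") is not None

    @classmethod
    def match_json(cls, model_spec: EmbeddingModelSpec) -> bool:
        # Accept any spec here; a real engine would inspect the spec fields.
        return True

    def load(self):
        from sentence_transformers import SentenceTransformer

        # Populate the attributes the base class now tracks (_model, _tokenizer).
        self._model = SentenceTransformer(self._model_path, device=self._device)
        self._tokenizer = self._model.tokenizer

    def create_embedding(self, sentences, **kwargs):
        # Return raw vectors; the real engines build an Embedding object with
        # usage accounting and call self._clean_cache_if_needed(...) afterwards.
        return self._model.encode(sentences, normalize_embeddings=True)
```

At load time, check_engine_by_model_name_and_engine (shown in the hunks further down) is what picks one of these subclasses.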
@@ -267,478 +210,52 @@ class EmbeddingModel:
  sentences = lines_decoded
  return sentences

+ @staticmethod
+ # copied from sentence-transformers
+ def _text_length(text):
+ if isinstance(text, dict): # {key: value} case
+ return len(next(iter(text.values())))
+ elif not hasattr(text, "__len__"): # Object has no len() method
+ return 1
+ elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
+ return len(text)
+ else:
+ return sum([len(t) for t in text]) # Sum of length of individual strings
+
+ @abstractmethod
  def create_embedding(
  self,
  sentences: Union[str, List[str]],
  **kwargs,
  ):
- sentences = self._fix_langchain_openai_inputs(sentences)
- model_uid = kwargs.pop("model_uid", None)
- from sentence_transformers import SentenceTransformer
-
- kwargs.setdefault("normalize_embeddings", True)
-
- try:
- from FlagEmbedding import BGEM3FlagModel
-
- @no_type_check
- def _encode_bgem3(
- model: Union[SentenceTransformer, BGEM3FlagModel],
- sentences: Union[str, List[str]],
- batch_size: int = 32,
- show_progress_bar: bool = None,
- output_value: str = "sparse_embedding",
- convert_to_numpy: bool = True,
- convert_to_tensor: bool = False,
- device: str = None,
- normalize_embeddings: bool = False,
- **kwargs,
- ):
- """
- Computes sentence embeddings with bge-m3 model
- Nothing special here, just replace sentence-transformer with FlagEmbedding
- TODO: think about how to solve the redundant code of encode method in the future
-
- :param sentences: the sentences to embed
- :param batch_size: the batch size used for the computation
- :param show_progress_bar: Output a progress bar when encode sentences
- :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
- :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
- :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
- :param device: Which torch.device to use for the computation
- :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
-
- :return:
- By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
- """
- import torch
- from tqdm.autonotebook import trange
-
- if show_progress_bar is None:
- show_progress_bar = (
- logger.getEffectiveLevel() == logging.INFO
- or logger.getEffectiveLevel() == logging.DEBUG
- )
-
- if convert_to_tensor:
- convert_to_numpy = False
-
- if output_value != "sparse_embedding":
- convert_to_tensor = False
- convert_to_numpy = False
-
- input_was_string = False
- if isinstance(sentences, str) or not hasattr(
- sentences, "__len__"
- ): # Cast an individual sentence to a list with length 1
- sentences = [sentences]
- input_was_string = True
-
- if device is None:
- # Same as SentenceTransformer.py
- from sentence_transformers.util import get_device_name
-
- device = get_device_name()
- logger.info(f"Use pytorch device_name: {device}")
-
- all_embeddings = []
- all_token_nums = 0
-
- # The original code does not support other inference engines
- def _text_length(text):
- if isinstance(text, dict): # {key: value} case
- return len(next(iter(text.values())))
- elif not hasattr(text, "__len__"): # Object has no len() method
- return 1
- elif len(text) == 0 or isinstance(
- text[0], int
- ): # Empty string or list of ints
- return len(text)
- else:
- return sum(
- [len(t) for t in text]
- ) # Sum of length of individual strings
-
- length_sorted_idx = np.argsort(
- [-_text_length(sen) for sen in sentences]
- )
- sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
-
- for start_index in trange(
- 0,
- len(sentences),
- batch_size,
- desc="Batches",
- disable=not show_progress_bar,
- ):
- sentences_batch = sentences_sorted[
- start_index : start_index + batch_size
- ]
-
- with torch.no_grad():
- out_features = model.encode(sentences_batch, **kwargs)
-
- if output_value == "token_embeddings":
- embeddings = []
- for token_emb, attention in zip(
- out_features[output_value],
- out_features["attention_mask"],
- ):
- last_mask_id = len(attention) - 1
- while (
- last_mask_id > 0
- and attention[last_mask_id].item() == 0
- ):
- last_mask_id -= 1
-
- embeddings.append(token_emb[0 : last_mask_id + 1])
- elif output_value is None: # Return all outputs
- embeddings = []
- for sent_idx in range(
- len(out_features["sentence_embedding"])
- ):
- row = {
- name: out_features[name][sent_idx]
- for name in out_features
- }
- embeddings.append(row)
- # for sparse embedding
- else:
- if kwargs.get("return_sparse"):
- embeddings = out_features["lexical_weights"]
- else:
- embeddings = out_features["dense_vecs"]
-
- if convert_to_numpy:
- embeddings = embeddings.cpu()
-
- all_embeddings.extend(embeddings)
-
- all_embeddings = [
- all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
- ]
-
- if convert_to_tensor:
- if len(all_embeddings):
- all_embeddings = torch.stack(all_embeddings)
- else:
- all_embeddings = torch.Tensor()
- elif convert_to_numpy:
- all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
-
- if input_was_string:
- all_embeddings = all_embeddings[0]
-
- return all_embeddings, all_token_nums
-
- except ImportError:
- _encode_bgem3 = None
-
- # copied from sentence-transformers, and modify it to return tokens num
- @no_type_check
- def encode(
- model: SentenceTransformer,
- sentences: Union[str, List[str]],
- prompt_name: Optional[str] = None,
- prompt: Optional[str] = None,
- batch_size: int = 32,
- show_progress_bar: bool = None,
- output_value: str = "sentence_embedding",
- convert_to_numpy: bool = True,
- convert_to_tensor: bool = False,
- device: str = None,
- normalize_embeddings: bool = False,
- **kwargs,
- ):
- """
- Computes sentence embeddings
-
- :param sentences: the sentences to embed
- :param batch_size: the batch size used for the computation
- :param show_progress_bar: Output a progress bar when encode sentences
- :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
- :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
- :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
- :param device: Which torch.device to use for the computation
- :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
-
- :return:
- By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
- """
- import torch
- from sentence_transformers.util import batch_to_device
- from tqdm.autonotebook import trange
-
- model.eval()
- if show_progress_bar is None:
- show_progress_bar = (
- logger.getEffectiveLevel() == logging.INFO
- or logger.getEffectiveLevel() == logging.DEBUG
- )
-
- if convert_to_tensor:
- convert_to_numpy = False
-
- if output_value != "sentence_embedding":
- convert_to_tensor = False
- convert_to_numpy = False
-
- input_was_string = False
- if isinstance(sentences, str) or not hasattr(
- sentences, "__len__"
- ): # Cast an individual sentence to a list with length 1
- sentences = [sentences]
- input_was_string = True
-
- if prompt is None:
- if prompt_name is not None:
- try:
- prompt = model.prompts[prompt_name]
- except KeyError:
- raise ValueError(
- f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(model.prompts.keys())!r}."
- )
- elif model.default_prompt_name is not None:
- prompt = model.prompts.get(model.default_prompt_name, None)
- else:
- if prompt_name is not None:
- logger.warning(
- "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
- "Ignoring the `prompt_name` in favor of `prompt`."
- )
-
- extra_features = {}
- if prompt is not None:
- sentences = [prompt + sentence for sentence in sentences]
-
- # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
- # Tracking the prompt length allow us to remove the prompt during pooling
- tokenized_prompt = model.tokenize([prompt])
- if "input_ids" in tokenized_prompt:
- extra_features["prompt_length"] = (
- tokenized_prompt["input_ids"].shape[-1] - 1
- )
-
- if device is None:
- device = model._target_device
-
- if (
- "gte" in self._model_spec.model_name.lower()
- and "qwen2" in self._model_spec.model_name.lower()
- ):
- model.to(device)
-
- all_embeddings = []
- all_token_nums = 0
- length_sorted_idx = np.argsort(
- [-model._text_length(sen) for sen in sentences]
- )
- sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
-
- for start_index in trange(
- 0,
- len(sentences),
- batch_size,
- desc="Batches",
- disable=not show_progress_bar,
- ):
- sentences_batch = sentences_sorted[
- start_index : start_index + batch_size
- ]
- features = model.tokenize(sentences_batch)
- features = batch_to_device(features, device)
- features.update(extra_features)
- # when batching, the attention mask 1 means there is a token
- # thus we just sum up it to get the total number of tokens
- if "clip" in self._model_spec.model_name.lower():
- if "input_ids" in features and hasattr(
- features["input_ids"], "numel"
- ):
- all_token_nums += features["input_ids"].numel()
- if "pixel_values" in features and hasattr(
- features["pixel_values"], "numel"
- ):
- all_token_nums += features["pixel_values"].numel()
- else:
- all_token_nums += features["attention_mask"].sum().item()
-
- with torch.no_grad():
- out_features = model.forward(features, **kwargs)
-
- if output_value == "token_embeddings":
- embeddings = []
- for token_emb, attention in zip(
- out_features[output_value], out_features["attention_mask"]
- ):
- last_mask_id = len(attention) - 1
- while (
- last_mask_id > 0 and attention[last_mask_id].item() == 0
- ):
- last_mask_id -= 1
-
- embeddings.append(token_emb[0 : last_mask_id + 1])
- elif output_value is None: # Return all outputs
- embeddings = []
- for sent_idx in range(len(out_features["sentence_embedding"])):
- row = {
- name: out_features[name][sent_idx]
- for name in out_features
- }
- embeddings.append(row)
- else: # Sentence embeddings
- embeddings = out_features[output_value]
- embeddings = embeddings.detach()
- if normalize_embeddings:
- embeddings = torch.nn.functional.normalize(
- embeddings, p=2, dim=1
- )
-
- # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
- if convert_to_numpy:
- embeddings = embeddings.cpu()
-
- all_embeddings.extend(embeddings)
-
- all_embeddings = [
- all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
- ]
-
- if convert_to_tensor:
- if len(all_embeddings):
- all_embeddings = torch.stack(all_embeddings)
- else:
- all_embeddings = torch.Tensor()
- elif convert_to_numpy:
- all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
-
- if input_was_string:
- all_embeddings = all_embeddings[0]
-
- return all_embeddings, all_token_nums
-
- is_bge_m3_flag_model = (
- self._kwargs.get("hybrid_mode")
- and "m3" in self._model_spec.model_name.lower()
- )
- if (
- "gte" in self._model_spec.model_name.lower()
- and "qwen2" in self._model_spec.model_name.lower()
- ):
- all_embeddings, all_token_nums = encode(
- self._model,
- sentences,
- prompt_name="query",
- convert_to_numpy=False,
- **kwargs,
- )
- elif is_bge_m3_flag_model:
- assert _encode_bgem3 is not None
- all_embeddings, all_token_nums = _encode_bgem3(
- self._model, sentences, convert_to_numpy=False, **kwargs
- )
- elif "clip" in self._model_spec.model_name.lower():
- import base64
- import re
- from io import BytesIO
-
- from PIL import Image
-
- def base64_to_image(base64_str: str) -> Image.Image:
- # base64_data = re.sub("^data:image/.+;base64,", "", base64_str)
- base64_data = base64_str.split(",", 1)[1]
- byte_data = base64.b64decode(base64_data)
- image_data = BytesIO(byte_data)
- img = Image.open(image_data)
- return img
-
- objs: list[dict[str, str]] = []
- for item in sentences:
- if isinstance(item, dict):
- if item.get("text") is not None:
- objs.append(item["text"])
- elif item.get("image") is not None:
- if re.match(r"^data:image/.+;base64,", item["image"]):
- image = base64_to_image(item["image"])
- objs.append(image)
- else:
- objs.append(item["image"])
- else:
- logger.error("Please check the input data.")
- all_embeddings, all_token_nums = encode(
- self._model,
- objs,
- convert_to_numpy=False,
- **kwargs,
- )
- else:
- all_embeddings, all_token_nums = encode(
- self._model,
- sentences,
- convert_to_numpy=False,
- **kwargs,
- )
- if isinstance(sentences, str):
- all_embeddings = [all_embeddings]
- embedding_list = []
- for index, data in enumerate(all_embeddings):
- if kwargs.get("return_sparse") and is_bge_m3_flag_model:
- embedding_list.append(
- EmbeddingData(
- index=index,
- object="embedding",
- embedding={k: float(v) for k, v in data.items()},
- )
- )
- else:
- embedding_list.append(
- EmbeddingData(
- index=index, object="embedding", embedding=data.tolist()
- )
- )
- usage = EmbeddingUsage(
- prompt_tokens=all_token_nums, total_tokens=all_token_nums
- )
- result = Embedding(
- object=(
- "list" # type: ignore
- if not is_bge_m3_flag_model and not kwargs.get("return_sparse")
- else "dict"
- ),
- model=model_uid, # type: ignore
- model_replica=self._model_uid,
- data=embedding_list,
- usage=usage,
- )
+ """
+ Creating embeddings from sentences.

- # clean cache if possible
- self._counter += 1
- if (
- self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0
- or all_token_nums >= EMBEDDING_EMPTY_CACHE_TOKENS
- ):
- logger.debug(
- "Empty embedding cache, calling count %s, all_token_nums %s",
- self._counter,
- all_token_nums,
- )
- gc.collect()
- empty_cache()
+ Parameters
+ ----------
+ sentences: Union[str, List[str]]
+ Input text to embed, encoded as a string or array of tokens.
+ To embed multiple inputs in a single request, pass an array of strings or array of token arrays.

- return result
+ Returns
+ -------
+ Embedding
+ The resulted Embedding vector that can be easily consumed by machine learning models and algorithms.
+ """

  def convert_ids_to_tokens(
  self,
  batch_token_ids: Union[List[Union[int, str]], List[List[Union[int, str]]]],
  **kwargs,
  ) -> Union[List[str]]:
- batch_decoded_texts: List[str] = []
-
+ """
+ Convert token ids to tokens
+ """
  assert self._model is not None
-
  if isinstance(batch_token_ids, (int, str)):
- return self._model.tokenizer.convert_ids_to_tokens(
- [int(str(batch_token_ids))]
- )[0]
+ return self._tokenizer.decode([int(str(batch_token_ids))])[0]
+
+ batch_decoded_texts: List[str] = []

  # check if it's a nested list
  if (
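The new static _text_length() helper added in the hunk above mirrors the sentence-transformers implementation; the engines use it to sort inputs by length before batching. Its return values follow directly from the branches shown; a small illustration is below (the import path is assumed from the refactored module).

```python
# Illustration of the _text_length branches shown above; values follow the code.
from xinference.model.embedding.core import EmbeddingModel  # assumed path

print(EmbeddingModel._text_length({"query": "hello"}))  # dict -> len of first value: 5
print(EmbeddingModel._text_length(42))                  # no __len__ -> 1
print(EmbeddingModel._text_length([1, 2, 3]))           # list of ints -> 3
print(EmbeddingModel._text_length(["ab", "cde"]))       # list of strings -> 2 + 3 = 5
```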
@@ -747,58 +264,37 @@ class EmbeddingModel:
  and isinstance(batch_token_ids[0], list)
  ):
  for token_ids in batch_token_ids:
- token_ids = [int(token_id) for token_id in token_ids]
+ token_ids = [int(token_id) for token_id in token_ids] # type: ignore
  batch_decoded_texts.append(
- self._model.tokenizer.convert_ids_to_tokens(token_ids)
+ self._tokenizer.convert_ids_to_tokens(token_ids)
  )
  else:
- batch_token_ids = [int(token_id) for token_id in batch_token_ids]
- batch_decoded_texts = self._model.tokenizer.convert_ids_to_tokens(
- batch_token_ids
- )
+ batch_token_ids = [int(token_id) for token_id in batch_token_ids] # type: ignore
+ batch_decoded_texts = self._tokenizer.convert_ids_to_tokens(batch_token_ids)
  return batch_decoded_texts

-
- def match_embedding(
- model_name: str,
- download_hub: Optional[
- Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
- ] = None,
- ) -> EmbeddingModelSpec:
- from ..utils import download_from_modelscope
- from . import BUILTIN_EMBEDDING_MODELS, MODELSCOPE_EMBEDDING_MODELS
- from .custom import get_user_defined_embeddings
-
- # first, check whether it is a user-defined embedding model
- for model_spec in get_user_defined_embeddings():
- if model_name == model_spec.model_name:
- return model_spec
-
- if download_hub == "modelscope" and model_name in MODELSCOPE_EMBEDDING_MODELS:
- logger.debug(f"Embedding model {model_name} found in ModelScope.")
- return MODELSCOPE_EMBEDDING_MODELS[model_name]
- elif download_hub == "huggingface" and model_name in BUILTIN_EMBEDDING_MODELS:
- logger.debug(f"Embedding model {model_name} found in Huggingface.")
- return BUILTIN_EMBEDDING_MODELS[model_name]
- elif download_from_modelscope() and model_name in MODELSCOPE_EMBEDDING_MODELS:
- logger.debug(f"Embedding model {model_name} found in ModelScope.")
- return MODELSCOPE_EMBEDDING_MODELS[model_name]
- elif model_name in BUILTIN_EMBEDDING_MODELS:
- logger.debug(f"Embedding model {model_name} found in Huggingface.")
- return BUILTIN_EMBEDDING_MODELS[model_name]
- else:
- raise ValueError(
- f"Embedding model {model_name} not found, available"
- f"Huggingface: {BUILTIN_EMBEDDING_MODELS.keys()}"
- f"ModelScope: {MODELSCOPE_EMBEDDING_MODELS.keys()}"
- )
+ def _clean_cache_if_needed(self, all_token_nums: int):
+ # clean cache if possible
+ self._counter += 1
+ if (
+ self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0
+ or all_token_nums >= EMBEDDING_EMPTY_CACHE_TOKENS
+ ):
+ logger.debug(
+ "Empty embedding cache, calling count %s, all_token_nums %s",
+ self._counter,
+ all_token_nums,
+ )
+ gc.collect()
+ empty_cache()


  def create_embedding_model_instance(
  subpool_addr: str,
- devices: List[str],
+ devices: Optional[List[str]],
  model_uid: str,
  model_name: str,
+ model_engine: Optional[str],
  download_hub: Optional[
  Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
  ] = None,
@@ -809,7 +305,20 @@ def create_embedding_model_instance(
  if model_path is None:
  model_path = cache(model_spec)

- model = EmbeddingModel(model_uid, model_path, model_spec, **kwargs)
+ if model_engine is None:
+ # unlike LLM and for compatibility
+ # we use sentence_transformers as the default engine for all models
+ model_engine = "sentence_transformers"
+
+ from .embed_family import check_engine_by_model_name_and_engine
+
+ embedding_cls = check_engine_by_model_name_and_engine(
+ model_name,
+ model_engine,
+ )
+ devices = devices or ["cpu"]
+ # model class should be one of flag, fastembed, sentence_transformers
+ model = embedding_cls(model_uid, model_path, model_spec, **kwargs)
  model_description = EmbeddingModelDescription(
  subpool_addr, devices, model_spec, model_path=model_path
  )
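With this change an embedding model is resolved to an engine class through check_engine_by_model_name_and_engine(), falling back to "sentence_transformers" when no engine is given. A hedged sketch of driving this from the RESTful client follows; passing model_engine for embedding models is an assumption based on the new create_embedding_model_instance() signature above, while the rest uses the existing xinference client API.

```python
# Hedged sketch of exercising the new embedding engine selection via the client.
from xinference.client import Client

client = Client("http://localhost:9997")

# The engine defaults to "sentence_transformers" when omitted (see the diff above);
# forwarding model_engine for embedding models is an assumption from this release.
model_uid = client.launch_model(
    model_name="bge-m3",
    model_type="embedding",
    model_engine="sentence_transformers",
)

model = client.get_model(model_uid)
result = model.create_embedding(["Xinference 1.7.0 splits embedding backends into engines"])
print(len(result["data"][0]["embedding"]))  # dimensionality of the returned vector
```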