xinference 1.6.1__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (76)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +79 -2
  3. xinference/client/restful/restful_client.py +64 -2
  4. xinference/core/media_interface.py +123 -0
  5. xinference/core/model.py +31 -0
  6. xinference/core/supervisor.py +8 -17
  7. xinference/core/worker.py +5 -17
  8. xinference/deploy/cmdline.py +6 -2
  9. xinference/model/audio/chattts.py +24 -39
  10. xinference/model/audio/cosyvoice.py +18 -30
  11. xinference/model/audio/funasr.py +42 -0
  12. xinference/model/audio/model_spec.json +18 -0
  13. xinference/model/audio/model_spec_modelscope.json +19 -1
  14. xinference/model/audio/utils.py +75 -0
  15. xinference/model/core.py +1 -0
  16. xinference/model/embedding/__init__.py +74 -18
  17. xinference/model/embedding/core.py +98 -597
  18. xinference/model/embedding/embed_family.py +133 -0
  19. xinference/model/embedding/flag/__init__.py +13 -0
  20. xinference/model/embedding/flag/core.py +282 -0
  21. xinference/model/embedding/model_spec.json +24 -0
  22. xinference/model/embedding/model_spec_modelscope.json +24 -0
  23. xinference/model/embedding/sentence_transformers/__init__.py +13 -0
  24. xinference/model/embedding/sentence_transformers/core.py +399 -0
  25. xinference/model/embedding/vllm/__init__.py +0 -0
  26. xinference/model/embedding/vllm/core.py +95 -0
  27. xinference/model/image/model_spec.json +20 -2
  28. xinference/model/image/model_spec_modelscope.json +21 -2
  29. xinference/model/image/stable_diffusion/core.py +144 -53
  30. xinference/model/llm/llama_cpp/memory.py +4 -2
  31. xinference/model/llm/llm_family.json +57 -0
  32. xinference/model/llm/llm_family_modelscope.json +61 -0
  33. xinference/model/llm/sglang/core.py +4 -0
  34. xinference/model/llm/utils.py +11 -0
  35. xinference/model/llm/vllm/core.py +3 -0
  36. xinference/model/rerank/core.py +86 -4
  37. xinference/model/rerank/model_spec.json +24 -0
  38. xinference/model/rerank/model_spec_modelscope.json +24 -0
  39. xinference/model/rerank/utils.py +4 -3
  40. xinference/model/utils.py +38 -1
  41. xinference/model/video/diffusers.py +65 -3
  42. xinference/model/video/model_spec.json +31 -4
  43. xinference/model/video/model_spec_modelscope.json +32 -4
  44. xinference/web/ui/build/asset-manifest.json +6 -6
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/css/main.013f296b.css +2 -0
  47. xinference/web/ui/build/static/css/main.013f296b.css.map +1 -0
  48. xinference/web/ui/build/static/js/main.8a9e3ba0.js +3 -0
  49. xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/6595880facebca7ceace6f17cf21c3a5a9219a2f52fb0ba9f3cf1131eddbcf6b.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/aa998bc2d9c11853add6b8a2e08f50327f56d8824ccaaec92d6dde1b305f0d85.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/c748246b1d7bcebc16153be69f37e955bb2145526c47dd425aeeff70d3004dbc.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/e31234e95d60a5a7883fbcd70de2475dc1c88c90705df1a530abb68f86f80a51.json +1 -0
  56. xinference/web/ui/src/locales/en.json +18 -7
  57. xinference/web/ui/src/locales/ja.json +224 -0
  58. xinference/web/ui/src/locales/ko.json +224 -0
  59. xinference/web/ui/src/locales/zh.json +18 -7
  60. {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/METADATA +9 -8
  61. {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/RECORD +66 -57
  62. xinference/web/ui/build/static/css/main.337afe76.css +0 -2
  63. xinference/web/ui/build/static/css/main.337afe76.css.map +0 -1
  64. xinference/web/ui/build/static/js/main.ddf9eaee.js +0 -3
  65. xinference/web/ui/build/static/js/main.ddf9eaee.js.map +0 -1
  66. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +0 -1
  67. xinference/web/ui/node_modules/.cache/babel-loader/12e637ed5fa9ca6491b03892b6949c03afd4960fe36ac25744488e7e1982aa19.json +0 -1
  68. xinference/web/ui/node_modules/.cache/babel-loader/77ac2665a784e99501ae95d32ef5937837a0439a47e965d291b38e99cb619f5b.json +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/d4ed4e82bfe69915999ec83f5feaa4301c75ecc6bdf1c78f2d03e4671ecbefc8.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +0 -1
  71. xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +0 -1
  72. /xinference/web/ui/build/static/js/{main.ddf9eaee.js.LICENSE.txt → main.8a9e3ba0.js.LICENSE.txt} +0 -0
  73. {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/WHEEL +0 -0
  74. {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/entry_points.txt +0 -0
  75. {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/licenses/LICENSE +0 -0
  76. {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/top_level.txt +0 -0
xinference/model/embedding/core.py

@@ -16,16 +16,13 @@ import gc
  import logging
  import os
  from collections import defaultdict
- from typing import Dict, List, Literal, Optional, Tuple, Union, no_type_check
-
- import numpy as np
- import torch
+ from typing import Dict, List, Literal, Optional, Tuple, Union

  from ..._compat import ROOT_KEY, ErrorWrapper, ValidationError
  from ...device_utils import empty_cache
- from ...types import Embedding, EmbeddingData, EmbeddingUsage
  from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
  from ..utils import get_cache_dir, is_model_cached
+ from .embed_family import match_embedding

  logger = logging.getLogger(__name__)

@@ -49,6 +46,7 @@ def get_embedding_model_descriptions():
  return copy.deepcopy(EMBEDDING_MODEL_DESCRIPTIONS)


+ # this class define the basic info of embedding model
  class EmbeddingModelSpec(CacheableModelSpec):
  model_name: str
  dimensions: int
@@ -128,7 +126,11 @@ def get_cache_status(
  return is_model_cached(model_spec, MODEL_NAME_TO_REVISION)


- class EmbeddingModel:
+ import abc
+ from abc import abstractmethod
+
+
+ class EmbeddingModel(abc.ABC):
  def __init__(
  self,
  model_uid: str,
@@ -141,95 +143,36 @@ class EmbeddingModel:
  self._model_path = model_path
  self._device = device
  self._model = None
+ self._tokenizer = None
  self._counter = 0
  self._model_spec = model_spec
+ self._model_name = self._model_spec.model_name
  self._kwargs = kwargs

+ @classmethod
+ @abstractmethod
+ def check_lib(cls) -> bool:
+ pass
+
+ @classmethod
+ @abstractmethod
+ def match_json(cls, model_spec: EmbeddingModelSpec) -> bool:
+ pass
+
+ @classmethod
+ def match(cls, model_spec: EmbeddingModelSpec):
+ """
+ Return if the model_spec can be matched.
+ """
+ if not cls.check_lib():
+ return False
+ return cls.match_json(model_spec)
+
+ @abstractmethod
  def load(self):
- try:
- import sentence_transformers
- from sentence_transformers import SentenceTransformer
-
- if sentence_transformers.__version__ < "3.1.0":
- raise ValueError(
- "The sentence_transformers version must be greater than 3.1.0. "
- "Please upgrade your version via `pip install -U sentence_transformers` or refer to "
- "https://github.com/UKPLab/sentence-transformers"
- )
- except ImportError:
- error_message = "Failed to import module 'SentenceTransformer'"
- installation_guide = [
- "Please make sure 'sentence-transformers' is installed. ",
- "You can install it by `pip install sentence-transformers`\n",
- ]
- raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
- class XSentenceTransformer(SentenceTransformer):
- def to(self, *args, **kwargs):
- pass
-
- torch_dtype = None
- if torch_dtype_str := self._kwargs.get("torch_dtype"):
- try:
- torch_dtype = getattr(torch, torch_dtype_str)
- if torch_dtype not in [
- torch.float16,
- torch.float32,
- torch.bfloat16,
- ]:
- logger.warning(
- f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
- )
- torch_dtype = torch.float32
- except AttributeError:
- logger.warning(
- f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
- )
- torch_dtype = torch.float32
-
- if (
- "gte" in self._model_spec.model_name.lower()
- and "qwen2" in self._model_spec.model_name.lower()
- ):
- model_kwargs = {"device_map": "auto"}
- if torch_dtype:
- model_kwargs["torch_dtype"] = torch_dtype
- self._model = XSentenceTransformer(
- self._model_path,
- device=self._device,
- model_kwargs=model_kwargs,
- )
- elif (
- self._kwargs.get("hybrid_mode")
- and "m3" in self._model_spec.model_name.lower()
- ):
- try:
- from FlagEmbedding import BGEM3FlagModel
- except ImportError:
- error_message = "Failed to import module 'BGEM3FlagModel'"
- installation_guide = [
- "Please make sure 'FlagEmbedding' is installed. ",
- "You can install it by `pip install FlagEmbedding`\n",
- ]
- raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
- if torch_dtype and torch_dtype == torch.float16:
- model_kwargs = {"use_fp16": True}
- else:
- model_kwargs = {}
- self._model = BGEM3FlagModel(
- self._model_path,
- device=self._device,
- **model_kwargs,
- )
- else:
- model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
- self._model = SentenceTransformer(
- self._model_path,
- device=self._device,
- model_kwargs=model_kwargs,
- trust_remote_code=True,
- )
+ """
+ Load embedding model
+ """

  def _fix_langchain_openai_inputs(
  self, sentences: Union[str, List[str], Dict[str, str], List[Dict[str, str]]]
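This hunk replaces the monolithic `load` with an abstract contract: every embedding backend now subclasses `EmbeddingModel`, reports whether its library is importable (`check_lib`), decides from the spec whether it can serve a model (`match_json`), and supplies its own `load`/`create_embedding`. A rough sketch of what a concrete backend might look like under that contract; the class below is illustrative only and is not one of the backends shipped in 1.7.0 (those live in the new `sentence_transformers/`, `flag/` and `vllm/` packages listed above):

```python
import importlib.util
from typing import List, Union

from xinference.model.embedding.core import EmbeddingModel, EmbeddingModelSpec


class SentenceTransformersLikeModel(EmbeddingModel):
    """Illustrative backend; the real ones ship under xinference.model.embedding.*"""

    @classmethod
    def check_lib(cls) -> bool:
        # match() short-circuits to False when the backing library is missing
        return importlib.util.find_spec("sentence_transformers") is not None

    @classmethod
    def match_json(cls, model_spec: EmbeddingModelSpec) -> bool:
        # decide from the spec alone; here we simply accept any named model
        return bool(model_spec.model_name)

    def load(self):
        from sentence_transformers import SentenceTransformer

        # populate the attributes the base class now tracks (_model, _tokenizer)
        self._model = SentenceTransformer(self._model_path, device=self._device)
        self._tokenizer = self._model.tokenizer

    def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
        # real backends build the OpenAI-compatible Embedding payload and call
        # self._clean_cache_if_needed(); returning raw vectors keeps the sketch short
        return self._model.encode(sentences, normalize_embeddings=True).tolist()
```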
@@ -267,486 +210,52 @@ class EmbeddingModel:
  sentences = lines_decoded
  return sentences

+ @staticmethod
+ # copied from sentence-transformers
+ def _text_length(text):
+ if isinstance(text, dict): # {key: value} case
+ return len(next(iter(text.values())))
+ elif not hasattr(text, "__len__"): # Object has no len() method
+ return 1
+ elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
+ return len(text)
+ else:
+ return sum([len(t) for t in text]) # Sum of length of individual strings
+
+ @abstractmethod
  def create_embedding(
  self,
  sentences: Union[str, List[str]],
  **kwargs,
  ):
- sentences = self._fix_langchain_openai_inputs(sentences)
- model_uid = kwargs.pop("model_uid", None)
- from sentence_transformers import SentenceTransformer
-
- kwargs.setdefault("normalize_embeddings", True)
-
- try:
- from FlagEmbedding import BGEM3FlagModel
-
- @no_type_check
- def _encode_bgem3(
- model: Union[SentenceTransformer, BGEM3FlagModel],
- sentences: Union[str, List[str]],
- batch_size: int = 32,
- show_progress_bar: bool = None,
- output_value: str = "sparse_embedding",
- convert_to_numpy: bool = True,
- convert_to_tensor: bool = False,
- device: str = None,
- normalize_embeddings: bool = False,
- **kwargs,
- ):
- """
- Computes sentence embeddings with bge-m3 model
- Nothing special here, just replace sentence-transformer with FlagEmbedding
- TODO: think about how to solve the redundant code of encode method in the future
-
- :param sentences: the sentences to embed
- :param batch_size: the batch size used for the computation
- :param show_progress_bar: Output a progress bar when encode sentences
- :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
- :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
- :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
- :param device: Which torch.device to use for the computation
- :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
-
- :return:
- By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
- """
- import torch
- from tqdm.autonotebook import trange
-
- if show_progress_bar is None:
- show_progress_bar = (
- logger.getEffectiveLevel() == logging.INFO
- or logger.getEffectiveLevel() == logging.DEBUG
- )
+ """
+ Creating embeddings from sentences.

- if convert_to_tensor:
- convert_to_numpy = False
-
- if output_value != "sparse_embedding":
- convert_to_tensor = False
- convert_to_numpy = False
-
- input_was_string = False
- if isinstance(sentences, str) or not hasattr(
- sentences, "__len__"
- ): # Cast an individual sentence to a list with length 1
- sentences = [sentences]
- input_was_string = True
-
- if device is None:
- # Same as SentenceTransformer.py
- from sentence_transformers.util import get_device_name
-
- device = get_device_name()
- logger.info(f"Use pytorch device_name: {device}")
-
- all_embeddings = []
- all_token_nums = 0
-
- # The original code does not support other inference engines
- def _text_length(text):
- if isinstance(text, dict): # {key: value} case
- return len(next(iter(text.values())))
- elif not hasattr(text, "__len__"): # Object has no len() method
- return 1
- elif len(text) == 0 or isinstance(
- text[0], int
- ): # Empty string or list of ints
- return len(text)
- else:
- return sum(
- [len(t) for t in text]
- ) # Sum of length of individual strings
-
- length_sorted_idx = np.argsort(
- [-_text_length(sen) for sen in sentences]
- )
- sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
-
- for start_index in trange(
- 0,
- len(sentences),
- batch_size,
- desc="Batches",
- disable=not show_progress_bar,
- ):
- sentences_batch = sentences_sorted[
- start_index : start_index + batch_size
- ]
-
- with torch.no_grad():
- out_features = model.encode(sentences_batch, **kwargs)
-
- if output_value == "token_embeddings":
- embeddings = []
- for token_emb, attention in zip(
- out_features[output_value],
- out_features["attention_mask"],
- ):
- last_mask_id = len(attention) - 1
- while (
- last_mask_id > 0
- and attention[last_mask_id].item() == 0
- ):
- last_mask_id -= 1
-
- embeddings.append(token_emb[0 : last_mask_id + 1])
- elif output_value is None: # Return all outputs
- embeddings = []
- for sent_idx in range(
- len(out_features["sentence_embedding"])
- ):
- row = {
- name: out_features[name][sent_idx]
- for name in out_features
- }
- embeddings.append(row)
- # for sparse embedding
- else:
- if kwargs.get("return_sparse"):
- embeddings = out_features["lexical_weights"]
- else:
- embeddings = out_features["dense_vecs"]
-
- if convert_to_numpy:
- embeddings = embeddings.cpu()
-
- all_embeddings.extend(embeddings)
-
- all_embeddings = [
- all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
- ]
-
- if convert_to_tensor:
- if len(all_embeddings):
- all_embeddings = torch.stack(all_embeddings)
- else:
- all_embeddings = torch.Tensor()
- elif convert_to_numpy:
- all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
-
- if input_was_string:
- all_embeddings = all_embeddings[0]
-
- return all_embeddings, all_token_nums
-
- except ImportError:
- _encode_bgem3 = None
-
- # copied from sentence-transformers, and modify it to return tokens num
- @no_type_check
- def encode(
- model: SentenceTransformer,
- sentences: Union[str, List[str]],
- prompt_name: Optional[str] = None,
- prompt: Optional[str] = None,
- batch_size: int = 32,
- show_progress_bar: bool = None,
- output_value: str = "sentence_embedding",
- convert_to_numpy: bool = True,
- convert_to_tensor: bool = False,
- device: str = None,
- normalize_embeddings: bool = False,
- **kwargs,
- ):
- """
- Computes sentence embeddings
-
- :param sentences: the sentences to embed
- :param batch_size: the batch size used for the computation
- :param show_progress_bar: Output a progress bar when encode sentences
- :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
- :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
- :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
- :param device: Which torch.device to use for the computation
- :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
-
- :return:
- By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
- """
- import torch
- from sentence_transformers.util import batch_to_device
- from tqdm.autonotebook import trange
-
- model.eval()
- if show_progress_bar is None:
- show_progress_bar = (
- logger.getEffectiveLevel() == logging.INFO
- or logger.getEffectiveLevel() == logging.DEBUG
- )
+ Parameters
+ ----------
+ sentences: Union[str, List[str]]
+ Input text to embed, encoded as a string or array of tokens.
+ To embed multiple inputs in a single request, pass an array of strings or array of token arrays.

- if convert_to_tensor:
- convert_to_numpy = False
-
- if output_value != "sentence_embedding":
- convert_to_tensor = False
- convert_to_numpy = False
-
- input_was_string = False
- if isinstance(sentences, str) or not hasattr(
- sentences, "__len__"
- ): # Cast an individual sentence to a list with length 1
- sentences = [sentences]
- input_was_string = True
-
- if prompt is None:
- if prompt_name is not None:
- try:
- prompt = model.prompts[prompt_name]
- except KeyError:
- raise ValueError(
- f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(model.prompts.keys())!r}."
- )
- elif model.default_prompt_name is not None:
- prompt = model.prompts.get(model.default_prompt_name, None)
- else:
- if prompt_name is not None:
- logger.warning(
- "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
- "Ignoring the `prompt_name` in favor of `prompt`."
- )
-
- extra_features = {}
- if prompt is not None:
- sentences = [prompt + sentence for sentence in sentences]
-
- # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
- # Tracking the prompt length allow us to remove the prompt during pooling
- tokenized_prompt = model.tokenize([prompt])
- if "input_ids" in tokenized_prompt:
- extra_features["prompt_length"] = (
- tokenized_prompt["input_ids"].shape[-1] - 1
- )
-
- if device is None:
- device = model._target_device
-
- if (
- "gte" in self._model_spec.model_name.lower()
- and "qwen2" in self._model_spec.model_name.lower()
- ):
- model.to(device)
-
- all_embeddings = []
- all_token_nums = 0
- length_sorted_idx = np.argsort(
- [-model._text_length(sen) for sen in sentences]
- )
- sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
-
- for start_index in trange(
- 0,
- len(sentences),
- batch_size,
- desc="Batches",
- disable=not show_progress_bar,
- ):
- sentences_batch = sentences_sorted[
- start_index : start_index + batch_size
- ]
- features = model.tokenize(sentences_batch)
- features = batch_to_device(features, device)
- features.update(extra_features)
- # when batching, the attention mask 1 means there is a token
- # thus we just sum up it to get the total number of tokens
- if "clip" in self._model_spec.model_name.lower():
- if "input_ids" in features and hasattr(
- features["input_ids"], "numel"
- ):
- all_token_nums += features["input_ids"].numel()
- if "pixel_values" in features and hasattr(
- features["pixel_values"], "numel"
- ):
- all_token_nums += features["pixel_values"].numel()
- else:
- all_token_nums += features["attention_mask"].sum().item()
-
- with torch.no_grad():
- out_features = model.forward(features, **kwargs)
-
- if output_value == "token_embeddings":
- embeddings = []
- for token_emb, attention in zip(
- out_features[output_value], out_features["attention_mask"]
- ):
- last_mask_id = len(attention) - 1
- while (
- last_mask_id > 0 and attention[last_mask_id].item() == 0
- ):
- last_mask_id -= 1
-
- embeddings.append(token_emb[0 : last_mask_id + 1])
- elif output_value is None: # Return all outputs
- embeddings = []
- for sent_idx in range(len(out_features["sentence_embedding"])):
- row = {
- name: out_features[name][sent_idx]
- for name in out_features
- }
- embeddings.append(row)
- else: # Sentence embeddings
- embeddings = out_features[output_value]
- embeddings = embeddings.detach()
- if normalize_embeddings:
- embeddings = torch.nn.functional.normalize(
- embeddings, p=2, dim=1
- )
-
- # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
- if convert_to_numpy:
- embeddings = embeddings.cpu()
-
- all_embeddings.extend(embeddings)
-
- all_embeddings = [
- all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
- ]
-
- if convert_to_tensor:
- if len(all_embeddings):
- all_embeddings = torch.stack(all_embeddings)
- else:
- all_embeddings = torch.Tensor()
- elif convert_to_numpy:
- all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
-
- if input_was_string:
- all_embeddings = all_embeddings[0]
-
- return all_embeddings, all_token_nums
-
- is_bge_m3_flag_model = (
- self._kwargs.get("hybrid_mode")
- and "m3" in self._model_spec.model_name.lower()
- )
- if (
- "gte" in self._model_spec.model_name.lower()
- and "qwen2" in self._model_spec.model_name.lower()
- ):
- all_embeddings, all_token_nums = encode(
- self._model,
- sentences,
- prompt_name="query",
- convert_to_numpy=False,
- **kwargs,
- )
- elif is_bge_m3_flag_model:
- assert _encode_bgem3 is not None
- all_embeddings, all_token_nums = _encode_bgem3(
- self._model, sentences, convert_to_numpy=False, **kwargs
- )
- elif "clip" in self._model_spec.model_name.lower():
- import base64
- import re
- from io import BytesIO
-
- from PIL import Image
-
- def base64_to_image(base64_str: str) -> Image.Image:
- # base64_data = re.sub("^data:image/.+;base64,", "", base64_str)
- base64_data = base64_str.split(",", 1)[1]
- byte_data = base64.b64decode(base64_data)
- image_data = BytesIO(byte_data)
- img = Image.open(image_data)
- return img
-
- objs: list[str] = []
- if isinstance(sentences, str):
- objs.append(sentences)
- else:
- for item in sentences:
- if isinstance(item, dict):
- if item.get("text") is not None:
- objs.append(item["text"])
- elif item.get("image") is not None:
- if re.match(r"^data:image/.+;base64,", item["image"]):
- image = base64_to_image(item["image"])
- objs.append(image)
- else:
- objs.append(item["image"])
- else:
- raise ValueError("Please check the input data.")
- elif isinstance(item, str):
- objs.append(item)
- else:
- raise ValueError("Please check the input data.")
-
- all_embeddings, all_token_nums = encode(
- self._model,
- objs,
- convert_to_numpy=False,
- **kwargs,
- )
- else:
- all_embeddings, all_token_nums = encode(
- self._model,
- sentences,
- convert_to_numpy=False,
- **kwargs,
- )
- if isinstance(sentences, str):
- all_embeddings = [all_embeddings]
- embedding_list = []
- for index, data in enumerate(all_embeddings):
- if kwargs.get("return_sparse") and is_bge_m3_flag_model:
- embedding_list.append(
- EmbeddingData(
- index=index,
- object="embedding",
- embedding={k: float(v) for k, v in data.items()},
- )
- )
- else:
- embedding_list.append(
- EmbeddingData(
- index=index, object="embedding", embedding=data.tolist()
- )
- )
- usage = EmbeddingUsage(
- prompt_tokens=all_token_nums, total_tokens=all_token_nums
- )
- result = Embedding(
- object=(
- "list" # type: ignore
- if not is_bge_m3_flag_model and not kwargs.get("return_sparse")
- else "dict"
- ),
- model=model_uid, # type: ignore
- model_replica=self._model_uid,
- data=embedding_list,
- usage=usage,
- )
-
- # clean cache if possible
- self._counter += 1
- if (
- self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0
- or all_token_nums >= EMBEDDING_EMPTY_CACHE_TOKENS
- ):
- logger.debug(
- "Empty embedding cache, calling count %s, all_token_nums %s",
- self._counter,
- all_token_nums,
- )
- gc.collect()
- empty_cache()
-
- return result
+ Returns
+ -------
+ Embedding
+ The resulted Embedding vector that can be easily consumed by machine learning models and algorithms.
+ """

  def convert_ids_to_tokens(
  self,
  batch_token_ids: Union[List[Union[int, str]], List[List[Union[int, str]]]],
  **kwargs,
  ) -> Union[List[str]]:
- batch_decoded_texts: List[str] = []
-
+ """
+ Convert token ids to tokens
+ """
  assert self._model is not None
-
  if isinstance(batch_token_ids, (int, str)):
- return self._model.tokenizer.convert_ids_to_tokens(
- [int(str(batch_token_ids))]
- )[0]
+ return self._tokenizer.decode([int(str(batch_token_ids))])[0]
+
+ batch_decoded_texts: List[str] = []

  # check if it's a nested list
  if (
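The engine-specific `encode` implementations removed in this hunk move into the new backend packages, while `_text_length`, which they all use to sort inputs by descending length before batching, is hoisted into the base class. Its behaviour on a few inputs, shown as a small illustration rather than code from the release:

```python
from xinference.model.embedding.core import EmbeddingModel

# a plain string is measured by its character count
assert EmbeddingModel._text_length("hello world") == 11
# a dict input is measured by its first value
assert EmbeddingModel._text_length({"text": "abcd"}) == 4
# a list of token ids is measured by its element count
assert EmbeddingModel._text_length([101, 2023, 102]) == 3
# a list of strings sums the individual lengths
assert EmbeddingModel._text_length(["ab", "cde"]) == 5
```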
@@ -755,58 +264,37 @@ class EmbeddingModel:
  and isinstance(batch_token_ids[0], list)
  ):
  for token_ids in batch_token_ids:
- token_ids = [int(token_id) for token_id in token_ids]
+ token_ids = [int(token_id) for token_id in token_ids] # type: ignore
  batch_decoded_texts.append(
- self._model.tokenizer.convert_ids_to_tokens(token_ids)
+ self._tokenizer.convert_ids_to_tokens(token_ids)
  )
  else:
- batch_token_ids = [int(token_id) for token_id in batch_token_ids]
- batch_decoded_texts = self._model.tokenizer.convert_ids_to_tokens(
- batch_token_ids
- )
+ batch_token_ids = [int(token_id) for token_id in batch_token_ids] # type: ignore
+ batch_decoded_texts = self._tokenizer.convert_ids_to_tokens(batch_token_ids)
  return batch_decoded_texts

-
- def match_embedding(
- model_name: str,
- download_hub: Optional[
- Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
- ] = None,
- ) -> EmbeddingModelSpec:
- from ..utils import download_from_modelscope
- from . import BUILTIN_EMBEDDING_MODELS, MODELSCOPE_EMBEDDING_MODELS
- from .custom import get_user_defined_embeddings
-
- # first, check whether it is a user-defined embedding model
- for model_spec in get_user_defined_embeddings():
- if model_name == model_spec.model_name:
- return model_spec
-
- if download_hub == "modelscope" and model_name in MODELSCOPE_EMBEDDING_MODELS:
- logger.debug(f"Embedding model {model_name} found in ModelScope.")
- return MODELSCOPE_EMBEDDING_MODELS[model_name]
- elif download_hub == "huggingface" and model_name in BUILTIN_EMBEDDING_MODELS:
- logger.debug(f"Embedding model {model_name} found in Huggingface.")
- return BUILTIN_EMBEDDING_MODELS[model_name]
- elif download_from_modelscope() and model_name in MODELSCOPE_EMBEDDING_MODELS:
- logger.debug(f"Embedding model {model_name} found in ModelScope.")
- return MODELSCOPE_EMBEDDING_MODELS[model_name]
- elif model_name in BUILTIN_EMBEDDING_MODELS:
- logger.debug(f"Embedding model {model_name} found in Huggingface.")
- return BUILTIN_EMBEDDING_MODELS[model_name]
- else:
- raise ValueError(
- f"Embedding model {model_name} not found, available"
- f"Huggingface: {BUILTIN_EMBEDDING_MODELS.keys()}"
- f"ModelScope: {MODELSCOPE_EMBEDDING_MODELS.keys()}"
- )
+ def _clean_cache_if_needed(self, all_token_nums: int):
+ # clean cache if possible
+ self._counter += 1
+ if (
+ self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0
+ or all_token_nums >= EMBEDDING_EMPTY_CACHE_TOKENS
+ ):
+ logger.debug(
+ "Empty embedding cache, calling count %s, all_token_nums %s",
+ self._counter,
+ all_token_nums,
+ )
+ gc.collect()
+ empty_cache()


  def create_embedding_model_instance(
  subpool_addr: str,
- devices: List[str],
+ devices: Optional[List[str]],
  model_uid: str,
  model_name: str,
+ model_engine: Optional[str],
  download_hub: Optional[
  Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
  ] = None,
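Besides dropping the old `match_embedding` (now imported from the new `embed_family` module), this hunk adds `_clean_cache_if_needed`, which centralises the cleanup that 1.6.x inlined at the bottom of `create_embedding`: it calls `gc.collect()`/`empty_cache()` every `EMBEDDING_EMPTY_CACHE_COUNT`-th request, or immediately when a single request used at least `EMBEDDING_EMPTY_CACHE_TOKENS` tokens. A backend is presumably expected to call it once per request, roughly as in this sketch; `_encode_batch` and `_build_result` below are hypothetical placeholders, and the other abstract methods are omitted for brevity:

```python
from typing import List, Union

from xinference.model.embedding.core import EmbeddingModel


class CountingBackend(EmbeddingModel):
    # check_lib / match_json / load omitted; see the earlier backend sketch

    def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
        vectors, token_count = self._encode_batch(sentences, **kwargs)  # hypothetical helper
        result = self._build_result(vectors)                            # hypothetical helper
        # same periodic/threshold cleanup the 1.6.x code performed inline
        self._clean_cache_if_needed(token_count)
        return result
```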
@@ -817,7 +305,20 @@ def create_embedding_model_instance(
  if model_path is None:
  model_path = cache(model_spec)

- model = EmbeddingModel(model_uid, model_path, model_spec, **kwargs)
+ if model_engine is None:
+ # unlike LLM and for compatibility
+ # we use sentence_transformers as the default engine for all models
+ model_engine = "sentence_transformers"
+
+ from .embed_family import check_engine_by_model_name_and_engine
+
+ embedding_cls = check_engine_by_model_name_and_engine(
+ model_name,
+ model_engine,
+ )
+ devices = devices or ["cpu"]
+ # model class should be one of flag, fastembed, sentence_transformers
+ model = embedding_cls(model_uid, model_path, model_spec, **kwargs)
  model_description = EmbeddingModelDescription(
  subpool_addr, devices, model_spec, model_path=model_path
  )
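With this change the backend is chosen per model at creation time: `check_engine_by_model_name_and_engine` from the new `embed_family` module maps the model name plus the requested engine to one of the concrete `EmbeddingModel` subclasses, and a missing `model_engine` falls back to `sentence_transformers` so existing callers keep working. Assuming only the parameters visible in this hunk and leaving the rest at their defaults, a direct call might look roughly like this; the address and uid are illustrative, and the return shape is not shown in this hunk:

```python
from xinference.model.embedding.core import create_embedding_model_instance

# model_engine=None keeps the pre-1.7 behaviour (sentence_transformers); per the
# comment in the hunk, engines such as "flag" or "fastembed" can be requested instead.
created = create_embedding_model_instance(
    subpool_addr="127.0.0.1:30001",  # illustrative worker sub-pool address
    devices=None,                    # normalised to ["cpu"] inside the function
    model_uid="bge-m3-0",            # illustrative uid
    model_name="bge-m3",             # a built-in embedding model name
    model_engine=None,
)
```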