xinference 1.6.0.post1__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +79 -2
- xinference/client/restful/restful_client.py +65 -3
- xinference/conftest.py +0 -7
- xinference/core/media_interface.py +132 -8
- xinference/core/model.py +44 -6
- xinference/core/scheduler.py +1 -10
- xinference/core/supervisor.py +8 -17
- xinference/core/worker.py +5 -27
- xinference/deploy/cmdline.py +6 -2
- xinference/model/audio/chattts.py +24 -39
- xinference/model/audio/cosyvoice.py +18 -30
- xinference/model/audio/funasr.py +42 -0
- xinference/model/audio/model_spec.json +71 -1
- xinference/model/audio/model_spec_modelscope.json +76 -2
- xinference/model/audio/utils.py +75 -0
- xinference/model/core.py +1 -0
- xinference/model/embedding/__init__.py +74 -18
- xinference/model/embedding/core.py +98 -589
- xinference/model/embedding/embed_family.py +133 -0
- xinference/{thirdparty/omnilmm/train → model/embedding/flag}/__init__.py +1 -1
- xinference/model/embedding/flag/core.py +282 -0
- xinference/model/embedding/model_spec.json +24 -0
- xinference/model/embedding/model_spec_modelscope.json +24 -0
- xinference/model/embedding/sentence_transformers/__init__.py +13 -0
- xinference/model/embedding/sentence_transformers/core.py +399 -0
- xinference/model/embedding/vllm/core.py +95 -0
- xinference/model/image/model_spec.json +30 -3
- xinference/model/image/model_spec_modelscope.json +41 -2
- xinference/model/image/stable_diffusion/core.py +144 -53
- xinference/model/llm/__init__.py +6 -54
- xinference/model/llm/core.py +19 -5
- xinference/model/llm/llama_cpp/core.py +59 -3
- xinference/model/llm/llama_cpp/memory.py +457 -0
- xinference/model/llm/llm_family.json +247 -402
- xinference/model/llm/llm_family.py +88 -16
- xinference/model/llm/llm_family_modelscope.json +260 -421
- xinference/model/llm/llm_family_openmind_hub.json +0 -34
- xinference/model/llm/sglang/core.py +8 -0
- xinference/model/llm/transformers/__init__.py +27 -6
- xinference/model/llm/transformers/chatglm.py +4 -2
- xinference/model/llm/transformers/core.py +49 -28
- xinference/model/llm/transformers/deepseek_v2.py +6 -49
- xinference/model/llm/transformers/gemma3.py +119 -164
- xinference/model/llm/transformers/multimodal/__init__.py +13 -0
- xinference/model/llm/transformers/{cogagent.py → multimodal/cogagent.py} +58 -95
- xinference/model/llm/transformers/multimodal/core.py +205 -0
- xinference/model/llm/transformers/{deepseek_vl2.py → multimodal/deepseek_vl2.py} +59 -120
- xinference/model/llm/transformers/multimodal/gemma3.py +117 -0
- xinference/model/llm/transformers/{glm4v.py → multimodal/glm4v.py} +57 -93
- xinference/model/llm/transformers/multimodal/intern_vl.py +412 -0
- xinference/model/llm/transformers/{minicpmv26.py → multimodal/minicpmv26.py} +55 -102
- xinference/model/llm/transformers/{ovis2.py → multimodal/ovis2.py} +114 -175
- xinference/model/llm/transformers/{qwen-omni.py → multimodal/qwen-omni.py} +82 -167
- xinference/model/llm/transformers/multimodal/qwen2_audio.py +131 -0
- xinference/model/llm/transformers/{qwen2_vl.py → multimodal/qwen2_vl.py} +224 -256
- xinference/model/llm/transformers/opt.py +4 -2
- xinference/model/llm/transformers/utils.py +6 -37
- xinference/model/llm/utils.py +11 -0
- xinference/model/llm/vllm/core.py +7 -0
- xinference/model/rerank/core.py +91 -3
- xinference/model/rerank/model_spec.json +24 -0
- xinference/model/rerank/model_spec_modelscope.json +24 -0
- xinference/model/rerank/utils.py +20 -2
- xinference/model/utils.py +38 -1
- xinference/model/video/diffusers.py +65 -3
- xinference/model/video/model_spec.json +31 -4
- xinference/model/video/model_spec_modelscope.json +32 -4
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.013f296b.css +2 -0
- xinference/web/ui/build/static/css/main.013f296b.css.map +1 -0
- xinference/web/ui/build/static/js/main.8a9e3ba0.js +3 -0
- xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/567e49df411efb24425d289bb484758cb57067ca54f8b5c67fe4505f698deb96.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6595880facebca7ceace6f17cf21c3a5a9219a2f52fb0ba9f3cf1131eddbcf6b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/aa998bc2d9c11853add6b8a2e08f50327f56d8824ccaaec92d6dde1b305f0d85.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c748246b1d7bcebc16153be69f37e955bb2145526c47dd425aeeff70d3004dbc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e31234e95d60a5a7883fbcd70de2475dc1c88c90705df1a530abb68f86f80a51.json +1 -0
- xinference/web/ui/src/locales/en.json +21 -8
- xinference/web/ui/src/locales/ja.json +224 -0
- xinference/web/ui/src/locales/ko.json +224 -0
- xinference/web/ui/src/locales/zh.json +21 -8
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/METADATA +14 -11
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/RECORD +93 -100
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/WHEEL +1 -1
- xinference/model/llm/transformers/cogvlm2.py +0 -442
- xinference/model/llm/transformers/cogvlm2_video.py +0 -333
- xinference/model/llm/transformers/deepseek_vl.py +0 -280
- xinference/model/llm/transformers/glm_edge_v.py +0 -213
- xinference/model/llm/transformers/intern_vl.py +0 -526
- xinference/model/llm/transformers/internlm2.py +0 -94
- xinference/model/llm/transformers/minicpmv25.py +0 -193
- xinference/model/llm/transformers/omnilmm.py +0 -132
- xinference/model/llm/transformers/qwen2_audio.py +0 -179
- xinference/model/llm/transformers/qwen_vl.py +0 -360
- xinference/thirdparty/omnilmm/LICENSE +0 -201
- xinference/thirdparty/omnilmm/chat.py +0 -218
- xinference/thirdparty/omnilmm/constants.py +0 -4
- xinference/thirdparty/omnilmm/conversation.py +0 -332
- xinference/thirdparty/omnilmm/model/__init__.py +0 -1
- xinference/thirdparty/omnilmm/model/omnilmm.py +0 -595
- xinference/thirdparty/omnilmm/model/resampler.py +0 -166
- xinference/thirdparty/omnilmm/model/utils.py +0 -578
- xinference/thirdparty/omnilmm/train/train_utils.py +0 -150
- xinference/thirdparty/omnilmm/utils.py +0 -134
- xinference/web/ui/build/static/css/main.337afe76.css +0 -2
- xinference/web/ui/build/static/css/main.337afe76.css.map +0 -1
- xinference/web/ui/build/static/js/main.ae579a97.js +0 -3
- xinference/web/ui/build/static/js/main.ae579a97.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +0 -1
- /xinference/{thirdparty/omnilmm → model/embedding/vllm}/__init__.py +0 -0
- /xinference/web/ui/build/static/js/{main.ae579a97.js.LICENSE.txt → main.8a9e3ba0.js.LICENSE.txt} +0 -0
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.6.0.post1.dist-info → xinference-1.7.0.dist-info}/top_level.txt +0 -0
xinference/model/embedding/core.py

@@ -16,16 +16,13 @@ import gc
 import logging
 import os
 from collections import defaultdict
-from typing import Dict, List, Literal, Optional, Tuple, Union
-
-import numpy as np
-import torch
+from typing import Dict, List, Literal, Optional, Tuple, Union

 from ..._compat import ROOT_KEY, ErrorWrapper, ValidationError
 from ...device_utils import empty_cache
-from ...types import Embedding, EmbeddingData, EmbeddingUsage
 from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
 from ..utils import get_cache_dir, is_model_cached
+from .embed_family import match_embedding

 logger = logging.getLogger(__name__)

@@ -49,6 +46,7 @@ def get_embedding_model_descriptions():
     return copy.deepcopy(EMBEDDING_MODEL_DESCRIPTIONS)


+# this class define the basic info of embedding model
 class EmbeddingModelSpec(CacheableModelSpec):
     model_name: str
     dimensions: int
@@ -128,7 +126,11 @@ def get_cache_status(
     return is_model_cached(model_spec, MODEL_NAME_TO_REVISION)


-class EmbeddingModel:
+import abc
+from abc import abstractmethod
+
+
+class EmbeddingModel(abc.ABC):
     def __init__(
         self,
         model_uid: str,
@@ -141,95 +143,36 @@ class EmbeddingModel:
         self._model_path = model_path
         self._device = device
         self._model = None
+        self._tokenizer = None
         self._counter = 0
         self._model_spec = model_spec
+        self._model_name = self._model_spec.model_name
         self._kwargs = kwargs

+    @classmethod
+    @abstractmethod
+    def check_lib(cls) -> bool:
+        pass
+
+    @classmethod
+    @abstractmethod
+    def match_json(cls, model_spec: EmbeddingModelSpec) -> bool:
+        pass
+
+    @classmethod
+    def match(cls, model_spec: EmbeddingModelSpec):
+        """
+        Return if the model_spec can be matched.
+        """
+        if not cls.check_lib():
+            return False
+        return cls.match_json(model_spec)
+
+    @abstractmethod
     def load(self):
-
-
-
-
-            if sentence_transformers.__version__ < "3.1.0":
-                raise ValueError(
-                    "The sentence_transformers version must be greater than 3.1.0. "
-                    "Please upgrade your version via `pip install -U sentence_transformers` or refer to "
-                    "https://github.com/UKPLab/sentence-transformers"
-                )
-        except ImportError:
-            error_message = "Failed to import module 'SentenceTransformer'"
-            installation_guide = [
-                "Please make sure 'sentence-transformers' is installed. ",
-                "You can install it by `pip install sentence-transformers`\n",
-            ]
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        class XSentenceTransformer(SentenceTransformer):
-            def to(self, *args, **kwargs):
-                pass
-
-        torch_dtype = None
-        if torch_dtype_str := self._kwargs.get("torch_dtype"):
-            try:
-                torch_dtype = getattr(torch, torch_dtype_str)
-                if torch_dtype not in [
-                    torch.float16,
-                    torch.float32,
-                    torch.bfloat16,
-                ]:
-                    logger.warning(
-                        f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
-                    )
-                    torch_dtype = torch.float32
-            except AttributeError:
-                logger.warning(
-                    f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
-                )
-                torch_dtype = torch.float32
-
-        if (
-            "gte" in self._model_spec.model_name.lower()
-            and "qwen2" in self._model_spec.model_name.lower()
-        ):
-            model_kwargs = {"device_map": "auto"}
-            if torch_dtype:
-                model_kwargs["torch_dtype"] = torch_dtype
-            self._model = XSentenceTransformer(
-                self._model_path,
-                device=self._device,
-                model_kwargs=model_kwargs,
-            )
-        elif (
-            self._kwargs.get("hybrid_mode")
-            and "m3" in self._model_spec.model_name.lower()
-        ):
-            try:
-                from FlagEmbedding import BGEM3FlagModel
-            except ImportError:
-                error_message = "Failed to import module 'BGEM3FlagModel'"
-                installation_guide = [
-                    "Please make sure 'FlagEmbedding' is installed. ",
-                    "You can install it by `pip install FlagEmbedding`\n",
-                ]
-                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-            if torch_dtype and torch_dtype == torch.float16:
-                model_kwargs = {"use_fp16": True}
-            else:
-                model_kwargs = {}
-            self._model = BGEM3FlagModel(
-                self._model_path,
-                device=self._device,
-                **model_kwargs,
-            )
-        else:
-            model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
-            self._model = SentenceTransformer(
-                self._model_path,
-                device=self._device,
-                model_kwargs=model_kwargs,
-                trust_remote_code=True,
-            )
+        """
+        Load embedding model
+        """

     def _fix_langchain_openai_inputs(
         self, sentences: Union[str, List[str], Dict[str, str], List[Dict[str, str]]]
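The hunk above rewrites EmbeddingModel as an abstract base class: an engine backend is expected to implement check_lib, match_json, load and create_embedding, while match() simply combines the library check with the spec check. The shipped backends added in this release live under xinference/model/embedding/sentence_transformers/, flag/ and vllm/; the sketch below only illustrates how a backend plugs into this interface, assuming the class and spec import paths shown in this diff, and the class name, method bodies and library check are assumptions, not code from the release.

# Hypothetical engine backend built on the abstract EmbeddingModel shown above.
# Everything except the imported base classes is an illustrative assumption.
import importlib.util
from typing import List, Union

from xinference.model.embedding.core import EmbeddingModel, EmbeddingModelSpec


class DummyEngineEmbeddingModel(EmbeddingModel):
    @classmethod
    def check_lib(cls) -> bool:
        # usable only if the backing library is importable in this environment
        return importlib.util.find_spec("sentence_transformers") is not None

    @classmethod
    def match_json(cls, model_spec: EmbeddingModelSpec) -> bool:
        # decide from the spec whether this engine can serve the model
        return model_spec.dimensions > 0

    def load(self):
        # load the weights from self._model_path onto self._device
        self._model = object()  # placeholder for the real engine object

    def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
        # encode `sentences` with self._model and return an Embedding payload
        raise NotImplementedError

In the shipped code such backends are collected by the new embed_family module and resolved per model name and engine (see check_engine_by_model_name_and_engine in the create_embedding_model_instance hunk further down).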
@@ -267,478 +210,52 @@ class EmbeddingModel:
             sentences = lines_decoded
         return sentences

+    @staticmethod
+    # copied from sentence-transformers
+    def _text_length(text):
+        if isinstance(text, dict):  # {key: value} case
+            return len(next(iter(text.values())))
+        elif not hasattr(text, "__len__"):  # Object has no len() method
+            return 1
+        elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
+            return len(text)
+        else:
+            return sum([len(t) for t in text])  # Sum of length of individual strings
+
+    @abstractmethod
     def create_embedding(
         self,
         sentences: Union[str, List[str]],
         **kwargs,
     ):
-
-
-        from sentence_transformers import SentenceTransformer
-
-        kwargs.setdefault("normalize_embeddings", True)
-
-        try:
-            from FlagEmbedding import BGEM3FlagModel
-
-            @no_type_check
-            def _encode_bgem3(
-                model: Union[SentenceTransformer, BGEM3FlagModel],
-                sentences: Union[str, List[str]],
-                batch_size: int = 32,
-                show_progress_bar: bool = None,
-                output_value: str = "sparse_embedding",
-                convert_to_numpy: bool = True,
-                convert_to_tensor: bool = False,
-                device: str = None,
-                normalize_embeddings: bool = False,
-                **kwargs,
-            ):
-                """
-                Computes sentence embeddings with bge-m3 model
-                Nothing special here, just replace sentence-transformer with FlagEmbedding
-                TODO: think about how to solve the redundant code of encode method in the future
-
-                :param sentences: the sentences to embed
-                :param batch_size: the batch size used for the computation
-                :param show_progress_bar: Output a progress bar when encode sentences
-                :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
-                :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
-                :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
-                :param device: Which torch.device to use for the computation
-                :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
-
-                :return:
-                    By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
-                """
-                import torch
-                from tqdm.autonotebook import trange
-
-                if show_progress_bar is None:
-                    show_progress_bar = (
-                        logger.getEffectiveLevel() == logging.INFO
-                        or logger.getEffectiveLevel() == logging.DEBUG
-                    )
-
-                if convert_to_tensor:
-                    convert_to_numpy = False
-
-                if output_value != "sparse_embedding":
-                    convert_to_tensor = False
-                    convert_to_numpy = False
-
-                input_was_string = False
-                if isinstance(sentences, str) or not hasattr(
-                    sentences, "__len__"
-                ):  # Cast an individual sentence to a list with length 1
-                    sentences = [sentences]
-                    input_was_string = True
-
-                if device is None:
-                    # Same as SentenceTransformer.py
-                    from sentence_transformers.util import get_device_name
-
-                    device = get_device_name()
-                    logger.info(f"Use pytorch device_name: {device}")
-
-                all_embeddings = []
-                all_token_nums = 0
-
-                # The original code does not support other inference engines
-                def _text_length(text):
-                    if isinstance(text, dict):  # {key: value} case
-                        return len(next(iter(text.values())))
-                    elif not hasattr(text, "__len__"):  # Object has no len() method
-                        return 1
-                    elif len(text) == 0 or isinstance(
-                        text[0], int
-                    ):  # Empty string or list of ints
-                        return len(text)
-                    else:
-                        return sum(
-                            [len(t) for t in text]
-                        )  # Sum of length of individual strings
-
-                length_sorted_idx = np.argsort(
-                    [-_text_length(sen) for sen in sentences]
-                )
-                sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
-
-                for start_index in trange(
-                    0,
-                    len(sentences),
-                    batch_size,
-                    desc="Batches",
-                    disable=not show_progress_bar,
-                ):
-                    sentences_batch = sentences_sorted[
-                        start_index : start_index + batch_size
-                    ]
-
-                    with torch.no_grad():
-                        out_features = model.encode(sentences_batch, **kwargs)
-
-                        if output_value == "token_embeddings":
-                            embeddings = []
-                            for token_emb, attention in zip(
-                                out_features[output_value],
-                                out_features["attention_mask"],
-                            ):
-                                last_mask_id = len(attention) - 1
-                                while (
-                                    last_mask_id > 0
-                                    and attention[last_mask_id].item() == 0
-                                ):
-                                    last_mask_id -= 1
-
-                                embeddings.append(token_emb[0 : last_mask_id + 1])
-                        elif output_value is None:  # Return all outputs
-                            embeddings = []
-                            for sent_idx in range(
-                                len(out_features["sentence_embedding"])
-                            ):
-                                row = {
-                                    name: out_features[name][sent_idx]
-                                    for name in out_features
-                                }
-                                embeddings.append(row)
-                        # for sparse embedding
-                        else:
-                            if kwargs.get("return_sparse"):
-                                embeddings = out_features["lexical_weights"]
-                            else:
-                                embeddings = out_features["dense_vecs"]
-
-                        if convert_to_numpy:
-                            embeddings = embeddings.cpu()
-
-                        all_embeddings.extend(embeddings)
-
-                all_embeddings = [
-                    all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
-                ]
-
-                if convert_to_tensor:
-                    if len(all_embeddings):
-                        all_embeddings = torch.stack(all_embeddings)
-                    else:
-                        all_embeddings = torch.Tensor()
-                elif convert_to_numpy:
-                    all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
-
-                if input_was_string:
-                    all_embeddings = all_embeddings[0]
-
-                return all_embeddings, all_token_nums
-
-        except ImportError:
-            _encode_bgem3 = None
-
-        # copied from sentence-transformers, and modify it to return tokens num
-        @no_type_check
-        def encode(
-            model: SentenceTransformer,
-            sentences: Union[str, List[str]],
-            prompt_name: Optional[str] = None,
-            prompt: Optional[str] = None,
-            batch_size: int = 32,
-            show_progress_bar: bool = None,
-            output_value: str = "sentence_embedding",
-            convert_to_numpy: bool = True,
-            convert_to_tensor: bool = False,
-            device: str = None,
-            normalize_embeddings: bool = False,
-            **kwargs,
-        ):
-            """
-            Computes sentence embeddings
-
-            :param sentences: the sentences to embed
-            :param batch_size: the batch size used for the computation
-            :param show_progress_bar: Output a progress bar when encode sentences
-            :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
-            :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
-            :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
-            :param device: Which torch.device to use for the computation
-            :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
-
-            :return:
-                By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
-            """
-            import torch
-            from sentence_transformers.util import batch_to_device
-            from tqdm.autonotebook import trange
-
-            model.eval()
-            if show_progress_bar is None:
-                show_progress_bar = (
-                    logger.getEffectiveLevel() == logging.INFO
-                    or logger.getEffectiveLevel() == logging.DEBUG
-                )
-
-            if convert_to_tensor:
-                convert_to_numpy = False
-
-            if output_value != "sentence_embedding":
-                convert_to_tensor = False
-                convert_to_numpy = False
-
-            input_was_string = False
-            if isinstance(sentences, str) or not hasattr(
-                sentences, "__len__"
-            ):  # Cast an individual sentence to a list with length 1
-                sentences = [sentences]
-                input_was_string = True
-
-            if prompt is None:
-                if prompt_name is not None:
-                    try:
-                        prompt = model.prompts[prompt_name]
-                    except KeyError:
-                        raise ValueError(
-                            f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(model.prompts.keys())!r}."
-                        )
-                elif model.default_prompt_name is not None:
-                    prompt = model.prompts.get(model.default_prompt_name, None)
-            else:
-                if prompt_name is not None:
-                    logger.warning(
-                        "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
-                        "Ignoring the `prompt_name` in favor of `prompt`."
-                    )
-
-            extra_features = {}
-            if prompt is not None:
-                sentences = [prompt + sentence for sentence in sentences]
-
-                # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
-                # Tracking the prompt length allow us to remove the prompt during pooling
-                tokenized_prompt = model.tokenize([prompt])
-                if "input_ids" in tokenized_prompt:
-                    extra_features["prompt_length"] = (
-                        tokenized_prompt["input_ids"].shape[-1] - 1
-                    )
-
-            if device is None:
-                device = model._target_device
-
-            if (
-                "gte" in self._model_spec.model_name.lower()
-                and "qwen2" in self._model_spec.model_name.lower()
-            ):
-                model.to(device)
-
-            all_embeddings = []
-            all_token_nums = 0
-            length_sorted_idx = np.argsort(
-                [-model._text_length(sen) for sen in sentences]
-            )
-            sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
-
-            for start_index in trange(
-                0,
-                len(sentences),
-                batch_size,
-                desc="Batches",
-                disable=not show_progress_bar,
-            ):
-                sentences_batch = sentences_sorted[
-                    start_index : start_index + batch_size
-                ]
-                features = model.tokenize(sentences_batch)
-                features = batch_to_device(features, device)
-                features.update(extra_features)
-                # when batching, the attention mask 1 means there is a token
-                # thus we just sum up it to get the total number of tokens
-                if "clip" in self._model_spec.model_name.lower():
-                    if "input_ids" in features and hasattr(
-                        features["input_ids"], "numel"
-                    ):
-                        all_token_nums += features["input_ids"].numel()
-                    if "pixel_values" in features and hasattr(
-                        features["pixel_values"], "numel"
-                    ):
-                        all_token_nums += features["pixel_values"].numel()
-                else:
-                    all_token_nums += features["attention_mask"].sum().item()
-
-                with torch.no_grad():
-                    out_features = model.forward(features, **kwargs)
-
-                    if output_value == "token_embeddings":
-                        embeddings = []
-                        for token_emb, attention in zip(
-                            out_features[output_value], out_features["attention_mask"]
-                        ):
-                            last_mask_id = len(attention) - 1
-                            while (
-                                last_mask_id > 0 and attention[last_mask_id].item() == 0
-                            ):
-                                last_mask_id -= 1
-
-                            embeddings.append(token_emb[0 : last_mask_id + 1])
-                    elif output_value is None:  # Return all outputs
-                        embeddings = []
-                        for sent_idx in range(len(out_features["sentence_embedding"])):
-                            row = {
-                                name: out_features[name][sent_idx]
-                                for name in out_features
-                            }
-                            embeddings.append(row)
-                    else:  # Sentence embeddings
-                        embeddings = out_features[output_value]
-                        embeddings = embeddings.detach()
-                        if normalize_embeddings:
-                            embeddings = torch.nn.functional.normalize(
-                                embeddings, p=2, dim=1
-                            )
-
-                        # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
-                        if convert_to_numpy:
-                            embeddings = embeddings.cpu()
-
-                    all_embeddings.extend(embeddings)
-
-            all_embeddings = [
-                all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
-            ]
-
-            if convert_to_tensor:
-                if len(all_embeddings):
-                    all_embeddings = torch.stack(all_embeddings)
-                else:
-                    all_embeddings = torch.Tensor()
-            elif convert_to_numpy:
-                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
-
-            if input_was_string:
-                all_embeddings = all_embeddings[0]
-
-            return all_embeddings, all_token_nums
-
-        is_bge_m3_flag_model = (
-            self._kwargs.get("hybrid_mode")
-            and "m3" in self._model_spec.model_name.lower()
-        )
-        if (
-            "gte" in self._model_spec.model_name.lower()
-            and "qwen2" in self._model_spec.model_name.lower()
-        ):
-            all_embeddings, all_token_nums = encode(
-                self._model,
-                sentences,
-                prompt_name="query",
-                convert_to_numpy=False,
-                **kwargs,
-            )
-        elif is_bge_m3_flag_model:
-            assert _encode_bgem3 is not None
-            all_embeddings, all_token_nums = _encode_bgem3(
-                self._model, sentences, convert_to_numpy=False, **kwargs
-            )
-        elif "clip" in self._model_spec.model_name.lower():
-            import base64
-            import re
-            from io import BytesIO
-
-            from PIL import Image
-
-            def base64_to_image(base64_str: str) -> Image.Image:
-                # base64_data = re.sub("^data:image/.+;base64,", "", base64_str)
-                base64_data = base64_str.split(",", 1)[1]
-                byte_data = base64.b64decode(base64_data)
-                image_data = BytesIO(byte_data)
-                img = Image.open(image_data)
-                return img
-
-            objs: list[dict[str, str]] = []
-            for item in sentences:
-                if isinstance(item, dict):
-                    if item.get("text") is not None:
-                        objs.append(item["text"])
-                    elif item.get("image") is not None:
-                        if re.match(r"^data:image/.+;base64,", item["image"]):
-                            image = base64_to_image(item["image"])
-                            objs.append(image)
-                        else:
-                            objs.append(item["image"])
-                else:
-                    logger.error("Please check the input data.")
-            all_embeddings, all_token_nums = encode(
-                self._model,
-                objs,
-                convert_to_numpy=False,
-                **kwargs,
-            )
-        else:
-            all_embeddings, all_token_nums = encode(
-                self._model,
-                sentences,
-                convert_to_numpy=False,
-                **kwargs,
-            )
-        if isinstance(sentences, str):
-            all_embeddings = [all_embeddings]
-        embedding_list = []
-        for index, data in enumerate(all_embeddings):
-            if kwargs.get("return_sparse") and is_bge_m3_flag_model:
-                embedding_list.append(
-                    EmbeddingData(
-                        index=index,
-                        object="embedding",
-                        embedding={k: float(v) for k, v in data.items()},
-                    )
-                )
-            else:
-                embedding_list.append(
-                    EmbeddingData(
-                        index=index, object="embedding", embedding=data.tolist()
-                    )
-                )
-        usage = EmbeddingUsage(
-            prompt_tokens=all_token_nums, total_tokens=all_token_nums
-        )
-        result = Embedding(
-            object=(
-                "list"  # type: ignore
-                if not is_bge_m3_flag_model and not kwargs.get("return_sparse")
-                else "dict"
-            ),
-            model=model_uid,  # type: ignore
-            model_replica=self._model_uid,
-            data=embedding_list,
-            usage=usage,
-        )
+        """
+        Creating embeddings from sentences.

-
-
-
-
-            or
-        ):
-            logger.debug(
-                "Empty embedding cache, calling count %s, all_token_nums %s",
-                self._counter,
-                all_token_nums,
-            )
-            gc.collect()
-            empty_cache()
+        Parameters
+        ----------
+        sentences: Union[str, List[str]]
+            Input text to embed, encoded as a string or array of tokens.
+            To embed multiple inputs in a single request, pass an array of strings or array of token arrays.

-
+        Returns
+        -------
+        Embedding
+            The resulted Embedding vector that can be easily consumed by machine learning models and algorithms.
+        """

     def convert_ids_to_tokens(
         self,
         batch_token_ids: Union[List[Union[int, str]], List[List[Union[int, str]]]],
         **kwargs,
     ) -> Union[List[str]]:
-
-
+        """
+        Convert token ids to tokens
+        """
         assert self._model is not None
-
         if isinstance(batch_token_ids, (int, str)):
-            return self.
-
-
+            return self._tokenizer.decode([int(str(batch_token_ids))])[0]
+
+        batch_decoded_texts: List[str] = []

         # check if it's a nested list
         if (
@@ -747,58 +264,37 @@ class EmbeddingModel:
             and isinstance(batch_token_ids[0], list)
         ):
             for token_ids in batch_token_ids:
-                token_ids = [int(token_id) for token_id in token_ids]
+                token_ids = [int(token_id) for token_id in token_ids]  # type: ignore
                 batch_decoded_texts.append(
-                    self.
+                    self._tokenizer.convert_ids_to_tokens(token_ids)
                 )
         else:
-            batch_token_ids = [int(token_id) for token_id in batch_token_ids]
-            batch_decoded_texts = self.
-                batch_token_ids
-            )
+            batch_token_ids = [int(token_id) for token_id in batch_token_ids]  # type: ignore
+            batch_decoded_texts = self._tokenizer.convert_ids_to_tokens(batch_token_ids)
         return batch_decoded_texts

-
-
-
-
-
-
-    )
-
-
-
-
-
-
-
-            return model_spec
-
-    if download_hub == "modelscope" and model_name in MODELSCOPE_EMBEDDING_MODELS:
-        logger.debug(f"Embedding model {model_name} found in ModelScope.")
-        return MODELSCOPE_EMBEDDING_MODELS[model_name]
-    elif download_hub == "huggingface" and model_name in BUILTIN_EMBEDDING_MODELS:
-        logger.debug(f"Embedding model {model_name} found in Huggingface.")
-        return BUILTIN_EMBEDDING_MODELS[model_name]
-    elif download_from_modelscope() and model_name in MODELSCOPE_EMBEDDING_MODELS:
-        logger.debug(f"Embedding model {model_name} found in ModelScope.")
-        return MODELSCOPE_EMBEDDING_MODELS[model_name]
-    elif model_name in BUILTIN_EMBEDDING_MODELS:
-        logger.debug(f"Embedding model {model_name} found in Huggingface.")
-        return BUILTIN_EMBEDDING_MODELS[model_name]
-    else:
-        raise ValueError(
-            f"Embedding model {model_name} not found, available"
-            f"Huggingface: {BUILTIN_EMBEDDING_MODELS.keys()}"
-            f"ModelScope: {MODELSCOPE_EMBEDDING_MODELS.keys()}"
-        )
+    def _clean_cache_if_needed(self, all_token_nums: int):
+        # clean cache if possible
+        self._counter += 1
+        if (
+            self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0
+            or all_token_nums >= EMBEDDING_EMPTY_CACHE_TOKENS
+        ):
+            logger.debug(
+                "Empty embedding cache, calling count %s, all_token_nums %s",
+                self._counter,
+                all_token_nums,
+            )
+            gc.collect()
+            empty_cache()


 def create_embedding_model_instance(
     subpool_addr: str,
-    devices: List[str],
+    devices: Optional[List[str]],
     model_uid: str,
     model_name: str,
+    model_engine: Optional[str],
     download_hub: Optional[
         Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
     ] = None,
@@ -809,7 +305,20 @@ def create_embedding_model_instance(
     if model_path is None:
         model_path = cache(model_spec)

-
+    if model_engine is None:
+        # unlike LLM and for compatibility
+        # we use sentence_transformers as the default engine for all models
+        model_engine = "sentence_transformers"
+
+    from .embed_family import check_engine_by_model_name_and_engine
+
+    embedding_cls = check_engine_by_model_name_and_engine(
+        model_name,
+        model_engine,
+    )
+    devices = devices or ["cpu"]
+    # model class should be one of flag, fastembed, sentence_transformers
+    model = embedding_cls(model_uid, model_path, model_spec, **kwargs)
     model_description = EmbeddingModelDescription(
         subpool_addr, devices, model_spec, model_path=model_path
     )