xinference 1.6.1__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +79 -2
- xinference/client/restful/restful_client.py +64 -2
- xinference/core/media_interface.py +123 -0
- xinference/core/model.py +31 -0
- xinference/core/supervisor.py +8 -17
- xinference/core/worker.py +5 -17
- xinference/deploy/cmdline.py +6 -2
- xinference/model/audio/chattts.py +24 -39
- xinference/model/audio/cosyvoice.py +18 -30
- xinference/model/audio/funasr.py +42 -0
- xinference/model/audio/model_spec.json +18 -0
- xinference/model/audio/model_spec_modelscope.json +19 -1
- xinference/model/audio/utils.py +75 -0
- xinference/model/core.py +1 -0
- xinference/model/embedding/__init__.py +74 -18
- xinference/model/embedding/core.py +98 -597
- xinference/model/embedding/embed_family.py +133 -0
- xinference/model/embedding/flag/__init__.py +13 -0
- xinference/model/embedding/flag/core.py +282 -0
- xinference/model/embedding/model_spec.json +24 -0
- xinference/model/embedding/model_spec_modelscope.json +24 -0
- xinference/model/embedding/sentence_transformers/__init__.py +13 -0
- xinference/model/embedding/sentence_transformers/core.py +399 -0
- xinference/model/embedding/vllm/__init__.py +0 -0
- xinference/model/embedding/vllm/core.py +95 -0
- xinference/model/image/model_spec.json +20 -2
- xinference/model/image/model_spec_modelscope.json +21 -2
- xinference/model/image/stable_diffusion/core.py +144 -53
- xinference/model/llm/llama_cpp/memory.py +4 -2
- xinference/model/llm/llm_family.json +57 -0
- xinference/model/llm/llm_family_modelscope.json +61 -0
- xinference/model/llm/sglang/core.py +4 -0
- xinference/model/llm/utils.py +11 -0
- xinference/model/llm/vllm/core.py +3 -0
- xinference/model/rerank/core.py +86 -4
- xinference/model/rerank/model_spec.json +24 -0
- xinference/model/rerank/model_spec_modelscope.json +24 -0
- xinference/model/rerank/utils.py +4 -3
- xinference/model/utils.py +38 -1
- xinference/model/video/diffusers.py +65 -3
- xinference/model/video/model_spec.json +31 -4
- xinference/model/video/model_spec_modelscope.json +32 -4
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.013f296b.css +2 -0
- xinference/web/ui/build/static/css/main.013f296b.css.map +1 -0
- xinference/web/ui/build/static/js/main.8a9e3ba0.js +3 -0
- xinference/web/ui/build/static/js/main.8a9e3ba0.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/34cfbfb7836e136ba3261cfd411cc554bf99ba24b35dcceebeaa4f008cb3c9dc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/55b9fb40b57fa926e8f05f31c2f96467e76e5ad62f033dca97c03f9e8c4eb4fe.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6595880facebca7ceace6f17cf21c3a5a9219a2f52fb0ba9f3cf1131eddbcf6b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/aa998bc2d9c11853add6b8a2e08f50327f56d8824ccaaec92d6dde1b305f0d85.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c748246b1d7bcebc16153be69f37e955bb2145526c47dd425aeeff70d3004dbc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e31234e95d60a5a7883fbcd70de2475dc1c88c90705df1a530abb68f86f80a51.json +1 -0
- xinference/web/ui/src/locales/en.json +18 -7
- xinference/web/ui/src/locales/ja.json +224 -0
- xinference/web/ui/src/locales/ko.json +224 -0
- xinference/web/ui/src/locales/zh.json +18 -7
- {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/METADATA +9 -8
- {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/RECORD +66 -57
- xinference/web/ui/build/static/css/main.337afe76.css +0 -2
- xinference/web/ui/build/static/css/main.337afe76.css.map +0 -1
- xinference/web/ui/build/static/js/main.ddf9eaee.js +0 -3
- xinference/web/ui/build/static/js/main.ddf9eaee.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/12e637ed5fa9ca6491b03892b6949c03afd4960fe36ac25744488e7e1982aa19.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/77ac2665a784e99501ae95d32ef5937837a0439a47e965d291b38e99cb619f5b.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d4ed4e82bfe69915999ec83f5feaa4301c75ecc6bdf1c78f2d03e4671ecbefc8.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +0 -1
- /xinference/web/ui/build/static/js/{main.ddf9eaee.js.LICENSE.txt → main.8a9e3ba0.js.LICENSE.txt} +0 -0
- {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/WHEEL +0 -0
- {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.6.1.dist-info → xinference-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# Copyright 2022-2025 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
from threading import Lock
|
|
17
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Type
|
|
18
|
+
|
|
19
|
+
from ..utils import is_valid_model_name
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from .core import EmbeddingModel, EmbeddingModelSpec
|
|
23
|
+
|
|
24
|
+
# Embedding implementation classes grouped by backend. They start empty here.
# NOTE(review): they are not filled in this file — presumably populated by the
# package __init__; confirm.
FLAG_EMBEDDER_CLASSES: List[Type["EmbeddingModel"]] = []
SENTENCE_TRANSFORMER_CLASSES: List[Type["EmbeddingModel"]] = []
VLLM_CLASSES: List[Type["EmbeddingModel"]] = []

# Builtin embedding model specs keyed by model name: the Hugging Face registry
# and the ModelScope registry. match_embedding looks specs up here, and
# register_embedding rejects user-defined names that already appear in either.
BUILTIN_EMBEDDING_MODELS: Dict[str, Any] = {}
MODELSCOPE_EMBEDDING_MODELS: Dict[str, Any] = {}

logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Desc: this file manages embedding model information.
|
|
35
|
+
def match_embedding(
    model_name: str,
    download_hub: Optional[
        Literal["huggingface", "modelscope", "openmind_hub", "csghub"]
    ] = None,
) -> "EmbeddingModelSpec":
    """
    Look up the spec of an embedding model by name.

    Resolution order:
      1. user-defined embedding models (exact name match);
      2. the hub explicitly requested through ``download_hub`` ("modelscope"
         or "huggingface"), if that hub's registry contains the model;
      3. the ModelScope registry, when ``download_from_modelscope()`` is true;
      4. the builtin (Hugging Face) registry.

    :param model_name: name of the embedding model to find.
    :param download_hub: preferred hub. Values other than "modelscope" and
        "huggingface" are not special-cased and fall through to steps 3-4.
    :return: the matching embedding model spec.
    :raises ValueError: if no registry contains ``model_name``.
    """
    from ..utils import download_from_modelscope

    # The builtin model info has been initialized by __init__.py from the
    # model_spec json files.
    from .custom import get_user_defined_embeddings

    # First, check whether it is a user-defined embedding model.
    for model_spec in get_user_defined_embeddings():
        if model_name == model_spec.model_name:
            return model_spec

    if download_hub == "modelscope" and model_name in MODELSCOPE_EMBEDDING_MODELS:
        logger.debug(f"Embedding model {model_name} found in ModelScope.")
        return MODELSCOPE_EMBEDDING_MODELS[model_name]
    elif download_hub == "huggingface" and model_name in BUILTIN_EMBEDDING_MODELS:
        logger.debug(f"Embedding model {model_name} found in Huggingface.")
        return BUILTIN_EMBEDDING_MODELS[model_name]
    elif download_from_modelscope() and model_name in MODELSCOPE_EMBEDDING_MODELS:
        logger.debug(f"Embedding model {model_name} found in ModelScope.")
        return MODELSCOPE_EMBEDDING_MODELS[model_name]
    elif model_name in BUILTIN_EMBEDDING_MODELS:
        logger.debug(f"Embedding model {model_name} found in Huggingface.")
        return BUILTIN_EMBEDDING_MODELS[model_name]
    else:
        # Separators are required between the implicitly concatenated parts;
        # without them the message ran "available" and both hub listings
        # together into one unreadable string.
        raise ValueError(
            f"Embedding model {model_name} not found, available "
            f"Huggingface: {list(BUILTIN_EMBEDDING_MODELS)}, "
            f"ModelScope: {list(MODELSCOPE_EMBEDDING_MODELS)}"
        )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# { embedding model name -> { engine name -> [engine params, ...] } }
# Each engine-params dict carries at least "model_name" (str) and
# "embedding_class" (an EmbeddingModel subclass); both keys are read by
# check_engine_by_model_name_and_engine, so the value type is mixed.
EMBEDDING_ENGINES: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
# NOTE(review): not referenced in this file; key meaning (presumably engine
# name) to be confirmed against its writers.
SUPPORTED_ENGINES: Dict[str, List[Type["EmbeddingModel"]]] = {}
# Held by register_embedding / unregister_embedding while they mutate
# UD_EMBEDDING_SPECS; unregister_embedding also drops the model's
# EMBEDDING_ENGINES entry under it.
UD_EMBEDDING_FAMILIES_LOCK = Lock()
# user defined embedding models: model name -> spec
UD_EMBEDDING_SPECS: Dict[str, "EmbeddingModelSpec"] = {}
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def register_embedding(custom_embedding_spec: "EmbeddingModelSpec", persist: bool):
    """
    Register a user-defined embedding model.

    The model name and the optional model URI are validated first. Then, while
    UD_EMBEDDING_FAMILIES_LOCK is held:
      - a name already present in the builtin, ModelScope or user-defined
        registries is rejected;
      - otherwise the spec is stored and its engine configuration generated.

    :param custom_embedding_spec: spec of the embedding model to register.
    :param persist: accepted for interface compatibility; this function does
        not use it.
    :raises ValueError: on an invalid name or URI, or on a name conflict.
    """
    from ..utils import is_valid_model_uri
    from . import generate_engine_config_by_model_name

    name = custom_embedding_spec.model_name
    if not is_valid_model_name(name):
        raise ValueError(f"Invalid model name {name}.")

    uri = custom_embedding_spec.model_uri
    if uri and not is_valid_model_uri(uri):
        raise ValueError(f"Invalid model URI {uri}.")

    with UD_EMBEDDING_FAMILIES_LOCK:
        known_registries = (
            BUILTIN_EMBEDDING_MODELS,
            MODELSCOPE_EMBEDDING_MODELS,
            UD_EMBEDDING_SPECS,
        )
        if any(name in registry for registry in known_registries):
            raise ValueError(f"Model name conflicts with existing model {name}")

        UD_EMBEDDING_SPECS[name] = custom_embedding_spec
        generate_engine_config_by_model_name(custom_embedding_spec)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# TODO: add persist feature
|
|
105
|
+
def unregister_embedding(custom_embedding_spec: "EmbeddingModelSpec"):
    """
    Forget a user-defined embedding model.

    Under UD_EMBEDDING_FAMILIES_LOCK, the model's entries in UD_EMBEDDING_SPECS
    and EMBEDDING_ENGINES are removed. Entries that do not exist are ignored.
    """
    with UD_EMBEDDING_FAMILIES_LOCK:
        name = custom_embedding_spec.model_name
        UD_EMBEDDING_SPECS.pop(name, None)
        EMBEDDING_ENGINES.pop(name, None)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def check_engine_by_model_name_and_engine(
    model_name: str,
    model_engine: str,
) -> Type["EmbeddingModel"]:
    """
    Return the embedding class registered for ``model_name`` on ``model_engine``.

    The engine name is matched case-insensitively against the engines
    registered for the model in EMBEDDING_ENGINES. If no registered name
    matches, the requested spelling is kept as-is and then rejected.

    :param model_name: name of the embedding model.
    :param model_engine: requested engine name, in any letter case.
    :return: the ``embedding_class`` of the first engine-params entry whose
        ``model_name`` equals ``model_name``.
    :raises ValueError: if the model is unknown, the engine is not registered
        for it, or no entry under that engine names this model.
    """
    if model_name not in EMBEDDING_ENGINES:
        raise ValueError(f"Model {model_name} not found.")

    engines = EMBEDDING_ENGINES[model_name]

    # Map the caller's spelling onto the registered (canonical) engine name.
    canonical = model_engine
    for registered in engines:
        if registered.lower() == model_engine.lower():
            canonical = registered
            break

    if canonical not in engines:
        raise ValueError(f"Model {model_name} cannot be run on engine {canonical}.")

    for params in engines[canonical]:
        if model_name == params["model_name"]:
            return params["embedding_class"]
    raise ValueError(f"Model {model_name} cannot be run on engine {canonical}.")
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2022-2025 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
# Copyright 2022-2025 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import importlib.util
|
|
16
|
+
import logging
|
|
17
|
+
from typing import List, Optional, Union, no_type_check
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
from FlagEmbedding.inference.embedder.model_mapping import (
|
|
24
|
+
support_native_bge_model_list,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
flag_installed = True
|
|
28
|
+
except ImportError:
|
|
29
|
+
flag_installed = False
|
|
30
|
+
|
|
31
|
+
from ....device_utils import get_available_device
|
|
32
|
+
from ....types import Embedding, EmbeddingData, EmbeddingUsage
|
|
33
|
+
from ..core import EmbeddingModel, EmbeddingModelSpec
|
|
34
|
+
|
|
35
|
+
# Models FlagEmbedding reports as natively supported. FlagEmbeddingModel.match_json
# checks spec model names against this list. It is empty when FlagEmbedding
# could not be imported.
FLAG_EMBEDDER_MODEL_LIST = support_native_bge_model_list() if flag_installed else []
logger = logging.getLogger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class FlagEmbeddingModel(EmbeddingModel):
    """
    Embedding model served through FlagEmbedding's ``BGEM3FlagModel``.

    Dense vectors are returned by default. Lexical-weight ("sparse")
    embeddings are returned when ``return_sparse`` is passed truthy to
    :meth:`create_embedding`.
    """

    def __init__(
        self,
        model_uid: str,
        model_path: str,
        model_spec: EmbeddingModelSpec,
        device: Optional[str] = None,
        return_sparse: bool = False,
        **kwargs,
    ):
        """
        :param return_sparse: stored and forwarded to the ``BGEM3FlagModel``
            constructor by :meth:`load`.

        The remaining arguments are passed unchanged to
        ``EmbeddingModel.__init__``.
        """
        super().__init__(model_uid, model_path, model_spec, device, **kwargs)
        self._return_sparse = return_sparse

    def load(self):
        """
        Instantiate ``BGEM3FlagModel`` from ``self._model_path``.

        The optional ``torch_dtype`` kwarg names a ``torch`` attribute (for
        example "float16"):
          - fp16 enables ``use_fp16``;
          - fp32 and bf16 pass no precision option;
          - any other value falls back to fp16 with a warning.

        :raises ImportError: if the FlagEmbedding package is not installed.
        """
        try:
            from FlagEmbedding import BGEM3FlagModel
        except ImportError:
            error_message = "Failed to import module 'BGEM3FlagModel'"
            installation_guide = [
                "Please make sure 'FlagEmbedding' is installed. ",
                "You can install it by `pip install FlagEmbedding`\n",
            ]
            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

        torch_dtype = None
        if torch_dtype_str := self._kwargs.get("torch_dtype"):
            try:
                torch_dtype = getattr(torch, torch_dtype_str)
                if torch_dtype not in [
                    torch.float16,
                    torch.float32,
                    torch.bfloat16,
                ]:
                    logger.warning(
                        f"BGE engine only support fp16, but got {torch_dtype_str}. Using default torch dtype: fp16."
                    )
                    torch_dtype = torch.float16
            except AttributeError:
                # The fallback assigned below is fp16, so the warning reports fp16.
                logger.warning(
                    f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp16."
                )
                torch_dtype = torch.float16

        # Only fp16 maps onto a BGEM3FlagModel option here; other dtypes pass
        # no precision kwarg.
        if torch_dtype and torch_dtype == torch.float16:
            model_kwargs = {"use_fp16": True}
        else:
            model_kwargs = {}
        self._model = BGEM3FlagModel(
            self._model_path,
            device=self._device,
            trust_remote_code=True,
            return_sparse=self._return_sparse,
            **model_kwargs,
        )
        self._tokenizer = self._model.tokenizer

    def create_embedding(
        self,
        sentences: Union[str, List[str]],
        **kwargs,
    ):
        """
        Embed ``sentences`` with the loaded ``BGEM3FlagModel``.

        :param sentences: a single string or a list of strings.
        :param kwargs: the following keys are recognized:
            - ``model_uid`` is popped and reported as the response's ``model``;
            - a truthy ``return_sparse`` selects lexical-weight output;
            - every remaining key is passed through the inner ``encode`` helper
              to ``BGEM3FlagModel.encode``.
        :return: an ``Embedding`` response with one ``EmbeddingData`` per input
            sentence. Token usage is not tracked and is reported as -1.
        """
        from FlagEmbedding import BGEM3FlagModel

        # Unlike sentence-transformers, FlagEmbedding's encode has no
        # ``normalize_embeddings`` parameter, so it is not defaulted here.
        model_uid = kwargs.pop("model_uid", None)

        @no_type_check
        def encode(
            model: Union[BGEM3FlagModel],
            sentences: Union[str, List[str]],
            batch_size: int = 32,
            show_progress_bar: Optional[bool] = None,
            output_value: str = "sparse_embedding",
            convert_to_numpy: bool = True,
            convert_to_tensor: bool = False,
            device: Optional[str] = None,
            normalize_embeddings: bool = False,
            **kwargs,
        ):
            """
            Computes sentence embeddings with bge-m3 model
            Nothing special here, just replace sentence-transformer with FlagEmbedding
            TODO: think about how to solve the redundant code of encode method in the future

            :param sentences: the sentences to embed
            :param batch_size: the batch size used for the computation
            :param show_progress_bar: Output a progress bar when encode sentences.
                Defaults to on when the logger level is INFO or DEBUG.
            :param output_value: Default "sparse_embedding". "token_embeddings"
                returns wordpiece token embeddings and None returns all output
                values. Any other value returns the dense vectors, or the lexical
                weights when kwargs["return_sparse"] is truthy.
            :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
            :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
            :param device: Which torch.device to use for the computation
            :param normalize_embeddings: Accepted for signature compatibility; not used here.

            :return:
               By default, a list of embeddings is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
               A single string input yields a single embedding instead of a list.
            """
            from tqdm.autonotebook import trange

            if show_progress_bar is None:
                show_progress_bar = (
                    logger.getEffectiveLevel() == logging.INFO
                    or logger.getEffectiveLevel() == logging.DEBUG
                )

            if convert_to_tensor:
                convert_to_numpy = False

            if output_value != "sparse_embedding":
                convert_to_tensor = False
                convert_to_numpy = False

            input_was_string = False
            if isinstance(sentences, str) or not hasattr(
                sentences, "__len__"
            ):  # Cast an individual sentence to a list with length 1
                sentences = [sentences]
                input_was_string = True

            if device is None:
                device = get_available_device()
                logger.info(f"Use pytorch device_name: {device}")

            all_embeddings = []

            # Encode longest sentences first; the original order is restored
            # after all batches are processed.
            length_sorted_idx = np.argsort(
                [-self._text_length(sen) for sen in sentences]
            )
            sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

            for start_index in trange(
                0,
                len(sentences),
                batch_size,
                desc="Batches",
                disable=not show_progress_bar,
            ):
                sentences_batch = sentences_sorted[
                    start_index : start_index + batch_size
                ]

                with torch.no_grad():
                    out_features = model.encode(sentences_batch, **kwargs)

                    if output_value == "token_embeddings":
                        embeddings = []
                        for token_emb, attention in zip(
                            out_features[output_value], out_features["attention_mask"]
                        ):
                            # Trim trailing padding positions (mask == 0).
                            last_mask_id = len(attention) - 1
                            while (
                                last_mask_id > 0 and attention[last_mask_id].item() == 0
                            ):
                                last_mask_id -= 1

                            embeddings.append(token_emb[0 : last_mask_id + 1])
                    elif output_value is None:  # Return all outputs
                        embeddings = []
                        for sent_idx in range(len(out_features["sentence_embedding"])):
                            row = {
                                name: out_features[name][sent_idx]
                                for name in out_features
                            }
                            embeddings.append(row)
                    # for sparse embedding
                    else:
                        # TODO: Here need check if we can return density_vecs and lexical_weights at the same time
                        if kwargs.get("return_sparse"):
                            embeddings = out_features["lexical_weights"]
                        else:
                            embeddings = out_features["dense_vecs"]

                        if convert_to_numpy:
                            embeddings = embeddings.cpu()

                    all_embeddings.extend(embeddings)

            # Undo the length sort so results line up with the input order.
            all_embeddings = [
                all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
            ]

            if convert_to_tensor:
                if len(all_embeddings):
                    all_embeddings = torch.stack(all_embeddings)
                else:
                    all_embeddings = torch.Tensor()
            elif convert_to_numpy:
                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

            if input_was_string:
                all_embeddings = all_embeddings[0]

            return all_embeddings

        all_embeddings = encode(
            self._model,
            sentences,
            convert_to_numpy=False,
            **kwargs,
        )

        # encode() unwraps single-string input; re-wrap so the loop below is uniform.
        if isinstance(sentences, str):
            all_embeddings = [all_embeddings]
        embedding_list = []
        for index, data in enumerate(all_embeddings):
            if kwargs.get("return_sparse"):
                embedding_list.append(
                    EmbeddingData(
                        index=index,
                        object="sparse_embedding",
                        embedding={k: float(v) for k, v in data.items()},
                    )
                )
            else:
                embedding_list.append(
                    EmbeddingData(
                        index=index, object="embedding", embedding=data.tolist()
                    )
                )
        usage = EmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
        # NOTE(review): dense results are labelled "dict" and sparse results
        # "list"; confirm this mapping is intended (OpenAI-style responses use
        # "list").
        result = Embedding(
            object=("list" if kwargs.get("return_sparse") else "dict"),  # type: ignore
            model=model_uid,
            model_replica=self._model_uid,
            data=embedding_list,
            usage=usage,
        )

        # clean cache if possible
        # TODO: support token statistics
        self._clean_cache_if_needed(all_token_nums=0)

        return result

    @classmethod
    def check_lib(cls) -> bool:
        """Return True if the FlagEmbedding package can be found."""
        return importlib.util.find_spec("FlagEmbedding") is not None

    @classmethod
    def match_json(cls, model_spec: EmbeddingModelSpec) -> bool:
        """Return True if FlagEmbedding natively supports ``model_spec.model_name``."""
        return model_spec.model_name in FLAG_EMBEDDER_MODEL_LIST
|
|
@@ -239,6 +239,30 @@
|
|
|
239
239
|
"model_id": "Alibaba-NLP/gte-Qwen2-7B-instruct",
|
|
240
240
|
"model_revision": "e26182b2122f4435e8b3ebecbf363990f409b45b"
|
|
241
241
|
},
|
|
242
|
+
{
|
|
243
|
+
"model_name": "Qwen3-Embedding-0.6B",
|
|
244
|
+
"dimensions": 1024,
|
|
245
|
+
"max_tokens": 32768,
|
|
246
|
+
"language": ["zh", "en"],
|
|
247
|
+
"model_id": "Qwen/Qwen3-Embedding-0.6B",
|
|
248
|
+
"model_revision": "744169034862c8eec56628663995004342e4e449"
|
|
249
|
+
},
|
|
250
|
+
{
|
|
251
|
+
"model_name": "Qwen3-Embedding-4B",
|
|
252
|
+
"dimensions": 2560,
|
|
253
|
+
"max_tokens": 32768,
|
|
254
|
+
"language": ["zh", "en"],
|
|
255
|
+
"model_id": "Qwen/Qwen3-Embedding-4B",
|
|
256
|
+
"model_revision": "408b81b7fab742073065d5b3661fa74c1b3ee0a1"
|
|
257
|
+
},
|
|
258
|
+
{
|
|
259
|
+
"model_name": "Qwen3-Embedding-8B",
|
|
260
|
+
"dimensions": 4096,
|
|
261
|
+
"max_tokens": 32768,
|
|
262
|
+
"language": ["zh", "en"],
|
|
263
|
+
"model_id": "Qwen/Qwen3-Embedding-8B",
|
|
264
|
+
"model_revision": "a3d38e32b9c835d5b3d0d0a3ef3c133bbea92539"
|
|
265
|
+
},
|
|
242
266
|
{
|
|
243
267
|
"model_name": "jina-embeddings-v3",
|
|
244
268
|
"dimensions": 1024,
|
|
@@ -241,6 +241,30 @@
|
|
|
241
241
|
"model_id": "iic/gte_Qwen2-7B-instruct",
|
|
242
242
|
"model_hub": "modelscope"
|
|
243
243
|
},
|
|
244
|
+
{
|
|
245
|
+
"model_name": "Qwen3-Embedding-0.6B",
|
|
246
|
+
"dimensions": 1024,
|
|
247
|
+
"max_tokens": 32768,
|
|
248
|
+
"language": ["zh", "en"],
|
|
249
|
+
"model_id": "Qwen/Qwen3-Embedding-0.6B",
|
|
250
|
+
"model_hub": "modelscope"
|
|
251
|
+
},
|
|
252
|
+
{
|
|
253
|
+
"model_name": "Qwen3-Embedding-4B",
|
|
254
|
+
"dimensions": 2560,
|
|
255
|
+
"max_tokens": 32768,
|
|
256
|
+
"language": ["zh", "en"],
|
|
257
|
+
"model_id": "Qwen/Qwen3-Embedding-4B",
|
|
258
|
+
"model_hub": "modelscope"
|
|
259
|
+
},
|
|
260
|
+
{
|
|
261
|
+
"model_name": "Qwen3-Embedding-8B",
|
|
262
|
+
"dimensions": 4096,
|
|
263
|
+
"max_tokens": 32768,
|
|
264
|
+
"language": ["zh", "en"],
|
|
265
|
+
"model_id": "Qwen/Qwen3-Embedding-8B",
|
|
266
|
+
"model_hub": "modelscope"
|
|
267
|
+
},
|
|
244
268
|
{
|
|
245
269
|
"model_name": "jina-embeddings-v3",
|
|
246
270
|
"dimensions": 1024,
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2022-2025 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|