fastembed-bio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0

fastembed/text/onnx_embedding.py
@@ -0,0 +1,353 @@
+from typing import Any, Iterable, Sequence, Type
+
+from fastembed.common.types import NumpyArray, OnnxProvider, Device
+from fastembed.common.onnx_model import OnnxOutputContext
+from fastembed.common.utils import define_cache_dir, normalize
+from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker
+from fastembed.text.text_embedding_base import TextEmbeddingBase
+from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+supported_onnx_models: list[DenseModelDescription] = [
+    DenseModelDescription(
+        model="BAAI/bge-base-en",
+        dim=768,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2023 year."
+        ),
+        license="mit",
+        size_in_GB=0.42,
+        sources=ModelSource(
+            hf="Qdrant/fast-bge-base-en",
+            url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz",
+            _deprecated_tar_struct=True,
+        ),
+        model_file="model_optimized.onnx",
+    ),
+    DenseModelDescription(
+        model="BAAI/bge-base-en-v1.5",
+        dim=768,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: not so necessary, 2023 year."
+        ),
+        license="mit",
+        size_in_GB=0.21,
+        sources=ModelSource(
+            hf="qdrant/bge-base-en-v1.5-onnx-q",
+            url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
+            _deprecated_tar_struct=True,
+        ),
+        model_file="model_optimized.onnx",
+    ),
+    DenseModelDescription(
+        model="BAAI/bge-large-en-v1.5",
+        dim=1024,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: not so necessary, 2023 year."
+        ),
+        license="mit",
+        size_in_GB=1.20,
+        sources=ModelSource(hf="qdrant/bge-large-en-v1.5-onnx"),
+        model_file="model.onnx",
+    ),
+    DenseModelDescription(
+        model="BAAI/bge-small-en",
+        dim=384,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2023 year."
+        ),
+        license="mit",
+        size_in_GB=0.13,
+        sources=ModelSource(
+            hf="Qdrant/bge-small-en",
+            url="https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz",
+            _deprecated_tar_struct=True,
+        ),
+        model_file="model_optimized.onnx",
+    ),
+    DenseModelDescription(
+        model="BAAI/bge-small-en-v1.5",
+        dim=384,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: not so necessary, 2023 year."
+        ),
+        license="mit",
+        size_in_GB=0.067,
+        sources=ModelSource(hf="qdrant/bge-small-en-v1.5-onnx-q"),
+        model_file="model_optimized.onnx",
+    ),
+    DenseModelDescription(
+        model="BAAI/bge-small-zh-v1.5",
+        dim=512,
+        description=(
+            "Text embeddings, Unimodal (text), Chinese, 512 input tokens truncation, "
+            "Prefixes for queries/documents: not so necessary, 2023 year."
+        ),
+        license="mit",
+        size_in_GB=0.09,
+        sources=ModelSource(
+            hf="Qdrant/bge-small-zh-v1.5",
+            url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz",
+            _deprecated_tar_struct=True,
+        ),
+        model_file="model_optimized.onnx",
+    ),
+    DenseModelDescription(
+        model="mixedbread-ai/mxbai-embed-large-v1",
+        dim=1024,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.64,
+        sources=ModelSource(hf="mixedbread-ai/mxbai-embed-large-v1"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="snowflake/snowflake-arctic-embed-xs",
+        dim=384,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.09,
+        sources=ModelSource(hf="snowflake/snowflake-arctic-embed-xs"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="snowflake/snowflake-arctic-embed-s",
+        dim=384,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.13,
+        sources=ModelSource(hf="snowflake/snowflake-arctic-embed-s"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="snowflake/snowflake-arctic-embed-m",
+        dim=768,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.43,
+        sources=ModelSource(hf="Snowflake/snowflake-arctic-embed-m"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="snowflake/snowflake-arctic-embed-m-long",
+        dim=768,
+        description=(
+            "Text embeddings, Unimodal (text), English, 2048 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.54,
+        sources=ModelSource(hf="snowflake/snowflake-arctic-embed-m-long"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="snowflake/snowflake-arctic-embed-l",
+        dim=1024,
+        description=(
+            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=1.02,
+        sources=ModelSource(hf="snowflake/snowflake-arctic-embed-l"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="jinaai/jina-clip-v1",
+        dim=768,
+        description=(
+            "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: "
+            "not necessary, 2024 year"
+        ),
+        license="apache-2.0",
+        size_in_GB=0.55,
+        sources=ModelSource(hf="jinaai/jina-clip-v1"),
+        model_file="onnx/text_model.onnx",
+    ),
+]
+
+
+class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[NumpyArray]):
+    """Implementation of the Flag Embedding model."""
+
+    @classmethod
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
+        """
+        Lists the supported models.
+
+        Returns:
+            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+        """
+        return supported_onnx_models
+
+    def __init__(
+        self,
+        model_name: str = "BAAI/bge-small-en-v1.5",
+        cache_dir: str | None = None,
+        threads: int | None = None,
+        providers: Sequence[OnnxProvider] | None = None,
+        cuda: bool | Device = Device.AUTO,
+        device_ids: list[int] | None = None,
+        lazy_load: bool = False,
+        device_id: int | None = None,
+        specific_model_path: str | None = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            model_name (str): The name of the model to use.
+            cache_dir (str, optional): The path to the cache directory.
+                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+                Defaults to `fastembed_cache` in the system's temp directory.
+            threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
+            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+            cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`
+                Defaults to Device.AUTO.
+            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
+                workers. Should be used with `cuda` equals to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
+                with `providers`. Defaults to None.
+            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
+            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
+            specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else
+
+        Raises:
+            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+        """
+        super().__init__(model_name, cache_dir, threads, **kwargs)
+        self.providers = providers
+        self.lazy_load = lazy_load
+        self._extra_session_options = self._select_exposed_session_options(kwargs)
+        # List of device ids, that can be used for data parallel processing in workers
+        self.device_ids = device_ids
+        self.cuda = cuda
+
+        # This device_id will be used if we need to load model in current process
+        self.device_id: int | None = None
+        if device_id is not None:
+            self.device_id = device_id
+        elif self.device_ids is not None:
+            self.device_id = self.device_ids[0]
+
+        self.model_description = self._get_model_description(model_name)
+        self.cache_dir = str(define_cache_dir(cache_dir))
+        self._specific_model_path = specific_model_path
+        self._model_dir = self.download_model(
+            self.model_description,
+            self.cache_dir,
+            local_files_only=self._local_files_only,
+            specific_model_path=self._specific_model_path,
+        )
+
+        if not self.lazy_load:
+            self.load_onnx_model()
+
+    def embed(
+        self,
+        documents: str | Iterable[str],
+        batch_size: int = 256,
+        parallel: int | None = None,
+        **kwargs: Any,
+    ) -> Iterable[NumpyArray]:
+        """
+        Encode a list of documents into list of embeddings.
+        We use mean pooling with attention so that the model can handle variable-length inputs.
+
+        Args:
+            documents: Iterator of documents or single document to embed
+            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+            parallel:
+                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                If 0, use all available cores.
+                If None, don't use data-parallel processing, use default onnxruntime threading instead.
+
+        Returns:
+            List of embeddings, one per document
+        """
+        yield from self._embed_documents(
+            model_name=self.model_name,
+            cache_dir=str(self.cache_dir),
+            documents=documents,
+            batch_size=batch_size,
+            parallel=parallel,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_ids=self.device_ids,
+            local_files_only=self._local_files_only,
+            specific_model_path=self._specific_model_path,
+            extra_session_options=self._extra_session_options,
+            **kwargs,
+        )
+
+    @classmethod
+    def _get_worker_class(cls) -> Type["TextEmbeddingWorker[NumpyArray]"]:
+        return OnnxTextEmbeddingWorker
+
+    def _preprocess_onnx_input(
+        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
+    ) -> dict[str, NumpyArray]:
+        """
+        Preprocess the onnx input.
+        """
+        return onnx_input
+
+    def _post_process_onnx_output(
+        self, output: OnnxOutputContext, **kwargs: Any
+    ) -> Iterable[NumpyArray]:
+        embeddings = output.model_output
+
+        if embeddings.ndim == 3:  # (batch_size, seq_len, embedding_dim)
+            processed_embeddings = embeddings[:, 0]
+        elif embeddings.ndim == 2:  # (batch_size, embedding_dim)
+            processed_embeddings = embeddings
+        else:
+            raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
+        return normalize(processed_embeddings)
+
+    def load_onnx_model(self) -> None:
+        self._load_onnx_model(
+            model_dir=self._model_dir,
+            model_file=self.model_description.model_file,
+            threads=self.threads,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_id=self.device_id,
+            extra_session_options=self._extra_session_options,
+        )
+
+    def token_count(
+        self, texts: str | Iterable[str], batch_size: int = 1024, **kwargs: Any
+    ) -> int:
+        return self._token_count(texts, batch_size=batch_size, **kwargs)
+
+
+class OnnxTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
+    def init_embedding(
+        self,
+        model_name: str,
+        cache_dir: str,
+        **kwargs: Any,
+    ) -> OnnxTextEmbedding:
+        return OnnxTextEmbedding(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=1,
+            **kwargs,
+        )
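
For orientation, a minimal usage sketch of the class added above. This is not part of the package diff; it assumes the wheel is installed and that the default model can be downloaded on first use:

from fastembed.text.onnx_embedding import OnnxTextEmbedding

# First use downloads "BAAI/bge-small-en-v1.5" (0.067 GB per the model list)
# into the cache dir (FASTEMBED_CACHE_PATH or the system temp directory).
model = OnnxTextEmbedding()

# embed() is a generator yielding one vector per document; for this model the
# output is the CLS token embedding (embeddings[:, 0]) passed through normalize().
vectors = list(model.embed(["hello world", "fastembed makes embeddings fast"]))
print(len(vectors), vectors[0].shape)  # 2 (384,)

# token_count() only tokenizes; it sums attention-mask entries per text.
print(model.token_count("hello world"))
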
fastembed/text/onnx_text_model.py
@@ -0,0 +1,180 @@
+import os
+from multiprocessing import get_all_start_methods
+from pathlib import Path
+from typing import Any, Iterable, Sequence, Type
+
+import numpy as np
+from numpy.typing import NDArray
+from tokenizers import Encoding, Tokenizer
+
+from fastembed.common.types import NumpyArray, OnnxProvider, Device
+from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
+from fastembed.common.preprocessor_utils import load_tokenizer
+from fastembed.common.utils import iter_batch
+from fastembed.parallel_processor import ParallelWorkerPool
+
+
+class OnnxTextModel(OnnxModel[T]):
+    ONNX_OUTPUT_NAMES: list[str] | None = None
+
+    @classmethod
+    def _get_worker_class(cls) -> Type["TextEmbeddingWorker[T]"]:
+        raise NotImplementedError("Subclasses must implement this method")
+
+    def _post_process_onnx_output(self, output: OnnxOutputContext, **kwargs: Any) -> Iterable[T]:
+        """Post-process the ONNX model output to convert it into a usable format.
+
+        Args:
+            output (OnnxOutputContext): The raw output from the ONNX model.
+            **kwargs: Additional keyword arguments that may be needed by specific implementations.
+
+        Returns:
+            Iterable[T]: Post-processed output as an iterable of type T.
+        """
+        raise NotImplementedError("Subclasses must implement this method")
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.tokenizer: Tokenizer | None = None
+        self.special_token_to_id: dict[str, int] = {}
+
+    def _preprocess_onnx_input(
+        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
+    ) -> dict[str, NumpyArray | NDArray[np.int64]]:
+        """
+        Preprocess the onnx input.
+        """
+        return onnx_input
+
+    def _load_onnx_model(
+        self,
+        model_dir: Path,
+        model_file: str,
+        threads: int | None,
+        providers: Sequence[OnnxProvider] | None = None,
+        cuda: bool | Device = Device.AUTO,
+        device_id: int | None = None,
+        extra_session_options: dict[str, Any] | None = None,
+    ) -> None:
+        super()._load_onnx_model(
+            model_dir=model_dir,
+            model_file=model_file,
+            threads=threads,
+            providers=providers,
+            cuda=cuda,
+            device_id=device_id,
+            extra_session_options=extra_session_options,
+        )
+        self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir)
+
+    def load_onnx_model(self) -> None:
+        raise NotImplementedError("Subclasses must implement this method")
+
+    def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
+        return self.tokenizer.encode_batch(documents)  # type: ignore[union-attr]
+
+    def onnx_embed(
+        self,
+        documents: list[str],
+        **kwargs: Any,
+    ) -> OnnxOutputContext:
+        encoded = self.tokenize(documents, **kwargs)
+        input_ids = np.array([e.ids for e in encoded])
+        attention_mask = np.array([e.attention_mask for e in encoded])
+        input_names = {node.name for node in self.model.get_inputs()}  # type: ignore[union-attr]
+        onnx_input: dict[str, NumpyArray] = {
+            "input_ids": np.array(input_ids, dtype=np.int64),
+        }
+        if "attention_mask" in input_names:
+            onnx_input["attention_mask"] = np.array(attention_mask, dtype=np.int64)
+        if "token_type_ids" in input_names:
+            onnx_input["token_type_ids"] = np.array(
+                [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64
+            )
+        onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)
+
+        model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)  # type: ignore[union-attr]
+        return OnnxOutputContext(
+            model_output=model_output[0],
+            attention_mask=onnx_input.get("attention_mask", attention_mask),
+            input_ids=onnx_input.get("input_ids", input_ids),
+        )
+
+    def _embed_documents(
+        self,
+        model_name: str,
+        cache_dir: str,
+        documents: str | Iterable[str],
+        batch_size: int = 256,
+        parallel: int | None = None,
+        providers: Sequence[OnnxProvider] | None = None,
+        cuda: bool | Device = Device.AUTO,
+        device_ids: list[int] | None = None,
+        local_files_only: bool = False,
+        specific_model_path: str | None = None,
+        extra_session_options: dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> Iterable[T]:
+        is_small = False
+
+        if isinstance(documents, str):
+            documents = [documents]
+            is_small = True
+
+        if isinstance(documents, list):
+            if len(documents) < batch_size:
+                is_small = True
+
+        if parallel is None or is_small:
+            if not hasattr(self, "model") or self.model is None:
+                self.load_onnx_model()
+            for batch in iter_batch(documents, batch_size):
+                yield from self._post_process_onnx_output(
+                    self.onnx_embed(batch, **kwargs), **kwargs
+                )
+        else:
+            if parallel == 0:
+                parallel = os.cpu_count()
+
+            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
+            params = {
+                "model_name": model_name,
+                "cache_dir": cache_dir,
+                "providers": providers,
+                "local_files_only": local_files_only,
+                "specific_model_path": specific_model_path,
+                **kwargs,
+            }
+
+            if extra_session_options is not None:
+                params.update(extra_session_options)
+
+            pool = ParallelWorkerPool(
+                num_workers=parallel or 1,
+                worker=self._get_worker_class(),
+                cuda=cuda,
+                device_ids=device_ids,
+                start_method=start_method,
+            )
+            for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
+                yield from self._post_process_onnx_output(batch, **kwargs)  # type: ignore
+
+    def _token_count(self, texts: str | Iterable[str], batch_size: int = 1024, **_: Any) -> int:
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()  # loads the tokenizer as well
+
+        token_num = 0
+        assert self.tokenizer is not None
+        texts = [texts] if isinstance(texts, str) else texts
+        for batch in iter_batch(texts, batch_size):
+            for tokens in self.tokenizer.encode_batch(batch):
+                token_num += sum(tokens.attention_mask)
+
+        return token_num
+
+
+class TextEmbeddingWorker(EmbeddingWorker[T]):
+    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]:
+        for idx, batch in items:
+            onnx_output = self.model.onnx_embed(batch)
+            yield idx, onnx_output
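
The serial-versus-parallel dispatch in _embed_documents above is the core scheduling decision in this file: a single string, or a list shorter than batch_size, is always encoded in the current process; otherwise parallel sizes a worker pool. A condensed restatement as a standalone helper (hypothetical, for illustration only; choose_workers is not part of the package, and note the real code can only detect "small" inputs for strings and lists, not generic iterators):

import os

def choose_workers(n_docs: int, batch_size: int, parallel: int | None) -> int:
    # Mirrors _embed_documents: 1 means "encode in-process with default
    # onnxruntime threading"; anything else sizes the ParallelWorkerPool.
    is_small = n_docs < batch_size
    if parallel is None or is_small:
        return 1
    if parallel == 0:
        return os.cpu_count() or 1  # 0 means "use all available cores"
    return parallel

assert choose_workers(10, 256, parallel=8) == 1          # small input stays serial
assert choose_workers(10_000, 256, parallel=None) == 1   # parallel unset: serial
assert choose_workers(10_000, 256, parallel=8) == 8      # explicit worker count
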
fastembed/text/pooled_embedding.py
@@ -0,0 +1,136 @@
+from typing import Any, Iterable, Type
+
+import numpy as np
+from numpy.typing import NDArray
+
+from fastembed.common.types import NumpyArray
+from fastembed.common.onnx_model import OnnxOutputContext
+from fastembed.common.utils import mean_pooling
+from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
+from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+supported_pooled_models: list[DenseModelDescription] = [
+    DenseModelDescription(
+        model="nomic-ai/nomic-embed-text-v1.5",
+        dim=768,
+        description=(
+            "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.52,
+        sources=ModelSource(hf="nomic-ai/nomic-embed-text-v1.5"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="nomic-ai/nomic-embed-text-v1.5-Q",
+        dim=768,
+        description=(
+            "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.13,
+        sources=ModelSource(hf="nomic-ai/nomic-embed-text-v1.5"),
+        model_file="onnx/model_quantized.onnx",
+    ),
+    DenseModelDescription(
+        model="nomic-ai/nomic-embed-text-v1",
+        dim=768,
+        description=(
+            "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.52,
+        sources=ModelSource(hf="nomic-ai/nomic-embed-text-v1"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+        dim=384,
+        description=(
+            "Text embeddings, Unimodal (text), Multilingual (~50 languages), 512 input tokens truncation, "
+            "Prefixes for queries/documents: not necessary, 2019 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=0.22,
+        sources=ModelSource(hf="qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q"),
+        model_file="model_optimized.onnx",
+    ),
+    DenseModelDescription(
+        model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+        dim=768,
+        description=(
+            "Text embeddings, Unimodal (text), Multilingual (~50 languages), 384 input tokens truncation, "
+            "Prefixes for queries/documents: not necessary, 2021 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=1.00,
+        sources=ModelSource(hf="xenova/paraphrase-multilingual-mpnet-base-v2"),
+        model_file="onnx/model.onnx",
+    ),
+    DenseModelDescription(
+        model="intfloat/multilingual-e5-large",
+        dim=1024,
+        description=(
+            "Text embeddings, Unimodal (text), Multilingual (~100 languages), 512 input tokens truncation, "
+            "Prefixes for queries/documents: necessary, 2024 year."
+        ),
+        license="mit",
+        size_in_GB=2.24,
+        sources=ModelSource(
+            hf="qdrant/multilingual-e5-large-onnx",
+            url="https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz",
+            _deprecated_tar_struct=True,
+        ),
+        model_file="model.onnx",
+        additional_files=["model.onnx_data"],
+    ),
+]
+
+
+class PooledEmbedding(OnnxTextEmbedding):
+    @classmethod
+    def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
+        return PooledEmbeddingWorker
+
+    @classmethod
+    def mean_pooling(
+        cls, model_output: NumpyArray, attention_mask: NDArray[np.int64]
+    ) -> NumpyArray:
+        return mean_pooling(model_output, attention_mask)
+
+    @classmethod
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
+        """Lists the supported models.
+
+        Returns:
+            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+        """
+        return supported_pooled_models
+
+    def _post_process_onnx_output(
+        self, output: OnnxOutputContext, **kwargs: Any
+    ) -> Iterable[NumpyArray]:
+        if output.attention_mask is None:
+            raise ValueError("attention_mask must be provided for document post-processing")
+
+        embeddings = output.model_output
+        attn_mask = output.attention_mask
+        return self.mean_pooling(embeddings, attn_mask)
+
+
+class PooledEmbeddingWorker(OnnxTextEmbeddingWorker):
+    def init_embedding(
+        self,
+        model_name: str,
+        cache_dir: str,
+        **kwargs: Any,
+    ) -> OnnxTextEmbedding:
+        return PooledEmbedding(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=1,
+            **kwargs,
+        )