fastembed-bio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
import string
|
|
2
|
+
from typing import Any, Iterable, Sequence, Type
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from tokenizers import Encoding, Tokenizer
|
|
6
|
+
|
|
7
|
+
from fastembed.common.preprocessor_utils import load_tokenizer
|
|
8
|
+
from fastembed.common.types import NumpyArray, Device
|
|
9
|
+
from fastembed.common import OnnxProvider
|
|
10
|
+
from fastembed.common.onnx_model import OnnxOutputContext
|
|
11
|
+
from fastembed.common.utils import define_cache_dir, iter_batch
|
|
12
|
+
from fastembed.late_interaction.late_interaction_embedding_base import (
|
|
13
|
+
LateInteractionTextEmbeddingBase,
|
|
14
|
+
)
|
|
15
|
+
from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker
|
|
16
|
+
from fastembed.common.model_description import DenseModelDescription, ModelSource
|
|
17
|
+
|
|
18
|
+
# Model descriptions served by `Colbert` (returned from `Colbert._list_supported_models`).
# `model_file` is the ONNX file loaded from the downloaded model directory in
# `Colbert.load_onnx_model`; `sources.hf` is the Hugging Face repo it is fetched from.
supported_colbert_models: list[DenseModelDescription] = [
    DenseModelDescription(
        model="colbert-ir/colbertv2.0",
        dim=128,
        description="Text embeddings, Unimodal (text), English, 512 input tokens truncation, 2023 year",
        license="mit",
        size_in_GB=0.44,
        sources=ModelSource(hf="colbert-ir/colbertv2.0"),
        model_file="model.onnx",
    ),
    DenseModelDescription(
        model="answerdotai/answerai-colbert-small-v1",
        dim=96,
        description="Text embeddings, Unimodal (text), English, 512 input tokens truncation, 2024 year",
        license="apache-2.0",
        size_in_GB=0.13,
        sources=ModelSource(hf="answerdotai/answerai-colbert-small-v1"),
        # this repo ships its ONNX export under a non-default file name
        model_file="vespa_colbert.onnx",
    ),
]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[NumpyArray]):
    """ColBERT late-interaction text embedding backed by an ONNX model.

    Every text is embedded into a sequence of token vectors rather than a single pooled
    vector. Before inference a marker token is inserted right after the first token to tell
    the model whether the input is a query or a document. Queries are padded with
    ``MASK_TOKEN`` up to ``MIN_QUERY_LENGTH`` tokens.
    """

    QUERY_MARKER_TOKEN_ID = 1
    DOCUMENT_MARKER_TOKEN_ID = 2
    MIN_QUERY_LENGTH = 31  # it's 32, we add one additional special token in the beginning
    MASK_TOKEN = "[MASK]"

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, is_doc: bool = True, **kwargs: Any
    ) -> Iterable[NumpyArray]:
        """Turn raw ONNX token embeddings into per-text multi-vector embeddings.

        Queries are yielded unchanged, one array per row of ``output.model_output``.

        For documents, the following happens:

        - Tokens whose id is in ``self.skip_list`` or equals ``self.pad_token_id`` are
          masked out.
        - The remaining token vectors are L2-normalized.
        - Only the unmasked vectors of each document are yielded.

        Args:
            output: ONNX output context. For documents, ``input_ids`` and
                ``attention_mask`` must be set.
            is_doc: Whether ``output`` was produced from documents (True) or queries (False).

        Yields:
            One array of token embeddings per input text.

        Raises:
            ValueError: If ``is_doc`` is True and ``input_ids`` or ``attention_mask`` is missing.
        """
        if not is_doc:
            yield from output.model_output
            return

        if output.input_ids is None or output.attention_mask is None:
            raise ValueError(
                "input_ids and attention_mask must be provided for document post-processing"
            )

        # Mask skip-list (punctuation) and padding tokens in one vectorized pass instead of a
        # Python-level loop over every token of every document.
        excluded_ids = set(self.skip_list)
        if self.pad_token_id is not None:
            excluded_ids.add(self.pad_token_id)
        # assumes input_ids and attention_mask share the same (batch, seq) shape, as the
        # [i, j] indexing of the previous implementation did
        excluded = np.isin(np.asarray(output.input_ids), list(excluded_ids))
        output.attention_mask[excluded] = 0  # mutated in place, as before

        output.model_output *= np.expand_dims(output.attention_mask, 2)
        norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
        # Clamp so that fully masked (all-zero) token vectors do not divide by zero.
        norm_clamped = np.maximum(norm, 1e-12)
        output.model_output /= norm_clamped

        for embedding, attention_mask in zip(output.model_output, output.attention_mask):
            yield embedding[attention_mask == 1]

    def _preprocess_onnx_input(
        self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
    ) -> dict[str, NumpyArray]:
        """Insert the document/query marker token at position 1 of every sequence.

        A matching ``1`` is inserted into ``attention_mask`` so the marker is attended to.
        Both arrays are cast to int64. ``onnx_input`` is updated and returned.
        """
        marker_token = self.DOCUMENT_MARKER_TOKEN_ID if is_doc else self.QUERY_MARKER_TOKEN_ID
        onnx_input["input_ids"] = np.insert(
            onnx_input["input_ids"].astype(np.int64), 1, marker_token, axis=1
        )
        onnx_input["attention_mask"] = np.insert(
            onnx_input["attention_mask"].astype(np.int64), 1, 1, axis=1
        )
        return onnx_input

    def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) -> list[Encoding]:
        """Tokenize documents with the document tokenizer, or a query with the query tokenizer.

        NOTE: in query mode only the FIRST element of ``documents`` is tokenized.
        ``query_embed`` always passes a single query at a time.
        """
        return (
            self._tokenize_documents(documents=documents)
            if is_doc
            else self._tokenize_query(query=next(iter(documents)))
        )

    def _tokenize_query(self, query: str) -> list[Encoding]:
        """Encode a single query; the query tokenizer pads it with the mask token."""
        assert self.query_tokenizer is not None
        encoded = self.query_tokenizer.encode_batch([query])
        return encoded

    def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
        """Encode a batch of documents with the document tokenizer."""
        encoded = self.tokenizer.encode_batch(documents)  # type: ignore[union-attr]
        return encoded

    def token_count(
        self,
        texts: str | Iterable[str],
        batch_size: int = 1024,
        is_doc: bool = True,
        include_extension: bool = False,
        **kwargs: Any,
    ) -> int:
        """Count the tokens the tokenizer attends to in ``texts``.

        Args:
            texts: A single text or an iterable of texts.
            batch_size: Number of texts tokenized per ``encode_batch`` call.
            is_doc: Use the document tokenizer (True) or the query tokenizer (False).
            include_extension: Also count one marker token per text and, in query mode,
                count each query as at least ``MIN_QUERY_LENGTH`` tokens (mask padding).

        Returns:
            Total number of tokens across all texts.
        """
        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()  # loads the tokenizer as well
        texts = [texts] if isinstance(texts, str) else texts
        tokenizer = self.tokenizer if is_doc else self.query_tokenizer
        assert tokenizer is not None

        token_num = 0
        for batch in iter_batch(texts, batch_size):
            for tokens in tokenizer.encode_batch(batch):
                attend_count = sum(tokens.attention_mask)
                if not is_doc and include_extension:
                    # queries are padded with mask tokens up to MIN_QUERY_LENGTH
                    attend_count = max(attend_count, self.MIN_QUERY_LENGTH)
                token_num += attend_count
            if include_extension:
                # add 1 for each DOCUMENT_MARKER_TOKEN_ID or QUERY_MARKER_TOKEN_ID
                token_num += len(batch)

        return token_num

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Lists the supported models.

        Returns:
            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
        """
        return supported_colbert_models

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        lazy_load: bool = False,
        device_id: int | None = None,
        specific_model_path: str | None = None,
        **kwargs: Any,
    ):
        """
        Args:
            model_name (str): The name of the model to use.
            cache_dir (str, optional): The path to the cache directory.
                                       Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                                       Defaults to `fastembed_cache` in the system's temp directory.
            threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
            cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`
                Defaults to Device.AUTO.
            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
                workers. Should be used with `cuda` equals to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
                with `providers`. Defaults to None.
            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
                Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
            specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else

        Raises:
            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
        """

        super().__init__(model_name, cache_dir, threads, **kwargs)
        self.providers = providers
        self.lazy_load = lazy_load
        self._extra_session_options = self._select_exposed_session_options(kwargs)

        # List of device ids, that can be used for data parallel processing in workers
        self.device_ids = device_ids
        self.cuda = cuda

        # This device_id will be used if we need to load model in current process
        self.device_id: int | None = None
        if device_id is not None:
            self.device_id = device_id
        elif self.device_ids is not None:
            self.device_id = self.device_ids[0]

        self.model_description = self._get_model_description(model_name)
        self.cache_dir = str(define_cache_dir(cache_dir))

        self._specific_model_path = specific_model_path
        self._model_dir = self.download_model(
            self.model_description,
            self.cache_dir,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
        )
        # Filled in by load_onnx_model().
        self.mask_token_id: int | None = None
        self.pad_token_id: int | None = None
        self.skip_list: set[int] = set()

        self.query_tokenizer: Tokenizer | None = None

        if not self.lazy_load:
            self.load_onnx_model()

    def load_onnx_model(self) -> None:
        """Load the ONNX session and configure the document and query tokenizers.

        This method performs the following steps:

        - Loads a separate query tokenizer.
        - Resolves the mask and pad token ids.
        - Builds ``skip_list`` from the ids of ``string.punctuation`` characters.
        - Shortens truncation by one token on both tokenizers, to leave room for the
          marker inserted in ``_preprocess_onnx_input``.
        - Makes the query tokenizer pad to ``MIN_QUERY_LENGTH`` with ``MASK_TOKEN``.
        """
        self._load_onnx_model(
            model_dir=self._model_dir,
            model_file=self.model_description.model_file,
            threads=self.threads,
            providers=self.providers,
            cuda=self.cuda,
            device_id=self.device_id,
            extra_session_options=self._extra_session_options,
        )
        self.query_tokenizer, _ = load_tokenizer(model_dir=self._model_dir)

        assert self.tokenizer is not None
        self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN]
        self.pad_token_id = self.tokenizer.padding["pad_id"]
        self.skip_list = {
            self.tokenizer.encode(symbol, add_special_tokens=False).ids[0]
            for symbol in string.punctuation
        }
        current_max_length = self.tokenizer.truncation["max_length"]
        # ensure not to overflow after adding document-marker
        self.tokenizer.enable_truncation(max_length=current_max_length - 1)
        self.query_tokenizer.enable_truncation(max_length=current_max_length - 1)
        self.query_tokenizer.enable_padding(
            pad_token=self.MASK_TOKEN,
            pad_id=self.mask_token_id,
            length=self.MIN_QUERY_LENGTH,
        )

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """
        Encode a list of documents into list of multi-vector embeddings.
        No pooling is applied: each document yields one vector per kept token
        (punctuation and padding tokens are dropped).

        Args:
            documents: Iterator of documents or single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one array of token vectors per document
        """
        yield from self._embed_documents(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            documents=documents,
            batch_size=batch_size,
            parallel=parallel,
            providers=self.providers,
            cuda=self.cuda,
            device_ids=self.device_ids,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
            extra_session_options=self._extra_session_options,
            **kwargs,
        )

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """Embed one or more queries, running the model once per query.

        Queries are processed one at a time because the query path of ``tokenize`` only
        tokenizes a single text.
        """
        if isinstance(query, str):
            query = [query]

        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()

        for text in query:
            yield from self._post_process_onnx_output(
                self.onnx_embed([text], is_doc=False), is_doc=False
            )

    @classmethod
    def _get_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
        return ColbertEmbeddingWorker
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class ColbertEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
    """Worker class returned by ``Colbert._get_worker_class`` for data-parallel embedding."""

    def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> Colbert:
        """Build the worker's ``Colbert`` instance, limited to one onnxruntime thread."""
        embedding_model = Colbert(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
        return embedding_model
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from typing import Any, Type
|
|
2
|
+
|
|
3
|
+
from fastembed.common.types import NumpyArray
|
|
4
|
+
from fastembed.late_interaction.colbert import Colbert, ColbertEmbeddingWorker
|
|
5
|
+
from fastembed.common.model_description import DenseModelDescription, ModelSource
|
|
6
|
+
|
|
7
|
+
# Model descriptions served by `JinaColbert` (returned from `JinaColbert._list_supported_models`).
supported_jina_colbert_models: list[DenseModelDescription] = [
    DenseModelDescription(
        model="jinaai/jina-colbert-v2",
        dim=128,
        description="New model that expands capabilities of colbert-v1 with multilingual and context length of 8192, 2024 year",
        license="cc-by-nc-4.0",
        size_in_GB=2.24,
        sources=ModelSource(hf="jinaai/jina-colbert-v2"),
        model_file="onnx/model.onnx",
        # presumably the ONNX external-data weights file that must sit next to model_file — confirm
        additional_files=["onnx/model.onnx_data"],
    )
]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class JinaColbert(Colbert):
    """Late-interaction embedding for ``jinaai/jina-colbert-v2``, built on :class:`Colbert`.

    Overrides the marker token ids and mask token for this model's vocabulary. Every query
    position is attended to, including the mask-token padding.
    """

    QUERY_MARKER_TOKEN_ID = 250002
    DOCUMENT_MARKER_TOKEN_ID = 250003
    MIN_QUERY_LENGTH = 31  # it's 32, we add one additional special token in the beginning
    MASK_TOKEN = "<mask>"

    @classmethod
    def _get_worker_class(cls) -> Type[ColbertEmbeddingWorker]:
        return JinaColbertEmbeddingWorker

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Lists the supported models.

        Returns:
            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
        """
        return supported_jina_colbert_models

    def _preprocess_onnx_input(
        self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
    ) -> dict[str, NumpyArray]:
        """Insert the marker token (see ``Colbert``), then unmask every query position.

        Extra keyword arguments are forwarded to the parent implementation.
        """
        onnx_input = super()._preprocess_onnx_input(onnx_input, is_doc, **kwargs)

        # the attention mask for jina-colbert-v2 is always 1 in queries
        if not is_doc:
            onnx_input["attention_mask"][:] = 1
        return onnx_input
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class JinaColbertEmbeddingWorker(ColbertEmbeddingWorker):
    """Worker class returned by ``JinaColbert._get_worker_class`` for data-parallel embedding."""

    def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> JinaColbert:
        """Build the worker's ``JinaColbert`` instance, limited to one onnxruntime thread."""
        embedding_model = JinaColbert(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
        return embedding_model
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from typing import Iterable, Any
|
|
2
|
+
|
|
3
|
+
from fastembed.common.model_description import DenseModelDescription
|
|
4
|
+
from fastembed.common.types import NumpyArray
|
|
5
|
+
from fastembed.common.model_management import ModelManagement
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class LateInteractionTextEmbeddingBase(ModelManagement[DenseModelDescription]):
    """Common interface for late-interaction (multi-vector) text embedding models.

    Subclasses are expected to override ``embed``, ``get_embedding_size``,
    ``embedding_size`` and ``token_count``; the versions here raise ``NotImplementedError``.
    ``passage_embed`` and ``query_embed`` default to delegating to ``embed``.
    """

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        **kwargs: Any,
    ):
        # Only ``local_files_only`` is read from kwargs here; anything else is ignored.
        self._local_files_only = kwargs.pop("local_files_only", False)
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.threads = threads
        # Starts empty; subclasses may use it to cache the embedding size.
        self._embedding_size: int | None = None

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        raise NotImplementedError()

    def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """
        Embed text passages.

        Args:
            texts (Iterable[str]): The passages to embed.
            **kwargs: Extra keyword arguments forwarded to ``embed``.

        Yields:
            NumpyArray: One embedding per passage.
        """
        # Model-specific subclasses may override this with a specialized implementation.
        passage_embeddings = self.embed(texts, **kwargs)
        yield from passage_embeddings

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """
        Embed one query or an iterable of queries.

        Args:
            query (Union[str, Iterable[str]]): A single query string, or an iterable of queries.

        Returns:
            Iterable[NumpyArray]: One embedding per query.
        """
        # Model-specific subclasses may override this with a specialized implementation.
        queries = [query] if isinstance(query, str) else query
        yield from self.embed(queries, **kwargs)

    @classmethod
    def get_embedding_size(cls, model_name: str) -> int:
        """Returns embedding size of the chosen model."""
        raise NotImplementedError("Subclasses must implement this method")

    @property
    def embedding_size(self) -> int:
        """Returns embedding size for the current model"""
        raise NotImplementedError("Subclasses must implement this method")

    def token_count(
        self,
        texts: str | Iterable[str],
        batch_size: int = 1024,
        **kwargs: Any,
    ) -> int:
        """Returns the number of tokens in the texts."""
        raise NotImplementedError("Subclasses must implement this method")
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
from typing import Any, Iterable, Sequence, Type
|
|
2
|
+
from dataclasses import asdict
|
|
3
|
+
|
|
4
|
+
from fastembed.common.model_description import DenseModelDescription
|
|
5
|
+
from fastembed.common.types import NumpyArray, Device
|
|
6
|
+
from fastembed.common import OnnxProvider
|
|
7
|
+
from fastembed.late_interaction.colbert import Colbert
|
|
8
|
+
from fastembed.late_interaction.jina_colbert import JinaColbert
|
|
9
|
+
from fastembed.late_interaction.late_interaction_embedding_base import (
|
|
10
|
+
LateInteractionTextEmbeddingBase,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LateInteractionTextEmbedding(LateInteractionTextEmbeddingBase):
    """Facade that dispatches to the late-interaction implementation supporting a model.

    ``__init__`` picks the first class in ``EMBEDDINGS_REGISTRY`` whose supported models
    contain ``model_name`` (case-insensitive) and stores the instance in ``self.model``.
    Embedding and token-counting calls are delegated to it.
    """

    EMBEDDINGS_REGISTRY: list[Type[LateInteractionTextEmbeddingBase]] = [Colbert, JinaColbert]

    @classmethod
    def list_supported_models(cls) -> list[dict[str, Any]]:
        """
        Lists the supported models.

        Returns:
            list[dict[str, Any]]: A list of dictionaries containing the model information.

        Example:
            ```
            [
                {
                    "model": "colbert-ir/colbertv2.0",
                    "dim": 128,
                    "description": "Late interaction model",
                    "license": "mit",
                    "size_in_GB": 0.44,
                    "sources": {
                        "hf": "colbert-ir/colbertv2.0",
                    },
                    "model_file": "model.onnx",
                },
            ]
            ```
        """
        return [asdict(model) for model in cls._list_supported_models()]

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Collect the model descriptions of every registered implementation, in registry order."""
        result: list[DenseModelDescription] = []
        for embedding in cls.EMBEDDINGS_REGISTRY:
            result.extend(embedding._list_supported_models())
        return result

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        lazy_load: bool = False,
        **kwargs: Any,
    ):
        """Instantiate the registered implementation that supports ``model_name``.

        All arguments are forwarded to the selected implementation.

        Raises:
            ValueError: If no registered implementation supports ``model_name``.
        """
        super().__init__(model_name, cache_dir, threads, **kwargs)
        for embedding_model_type in self.EMBEDDINGS_REGISTRY:
            supported_models = embedding_model_type._list_supported_models()
            if any(model_name.lower() == model.model.lower() for model in supported_models):
                self.model = embedding_model_type(
                    model_name,
                    cache_dir,
                    threads=threads,
                    providers=providers,
                    cuda=cuda,
                    device_ids=device_ids,
                    lazy_load=lazy_load,
                    **kwargs,
                )
                return

        raise ValueError(
            f"Model {model_name} is not supported in LateInteractionTextEmbedding. "
            "Please check the supported models using `LateInteractionTextEmbedding.list_supported_models()`"
        )

    @property
    def embedding_size(self) -> int:
        """Get the embedding size of the current model (looked up once, then cached)."""
        if self._embedding_size is None:
            self._embedding_size = self.get_embedding_size(self.model_name)
        return self._embedding_size

    @classmethod
    def get_embedding_size(cls, model_name: str) -> int:
        """Get the embedding size of the passed model

        Args:
            model_name (str): The name of the model to get embedding size for (case-insensitive).

        Returns:
            int: The size of the embedding.

        Raises:
            ValueError: If the model name is not found in the supported models.
        """
        descriptions = cls._list_supported_models()
        wanted = model_name.lower()
        # First matching description wins; a missing model (or missing dim) yields None.
        embedding_size: int | None = next(
            (description.dim for description in descriptions if description.model.lower() == wanted),
            None,
        )
        if embedding_size is None:
            model_names = [description.model for description in descriptions]
            raise ValueError(
                f"Embedding size for model {model_name} was None. "
                f"Available model names: {model_names}"
            )
        return embedding_size

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """
        Encode a list of documents into list of multi-vector embeddings.
        No pooling is applied: each document is represented by one vector per token.

        Args:
            documents: Iterator of documents or single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one array of token vectors per document
        """
        yield from self.model.embed(documents, batch_size, parallel, **kwargs)

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """
        Embeds queries

        Args:
            query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.

        Returns:
            Iterable[NumpyArray]: The embeddings.
        """

        # This is model-specific, so that different models can have specialized implementations
        yield from self.model.query_embed(query, **kwargs)

    def token_count(
        self,
        texts: str | Iterable[str],
        batch_size: int = 1024,
        is_doc: bool = True,
        include_extension: bool = False,
        **kwargs: Any,
    ) -> int:
        """Returns the number of tokens in the texts.

        Args:
            texts (str | Iterable[str]): The list of texts to count tokens for.
            batch_size (int): Batch size for encoding
            is_doc (bool): Whether the texts are documents (True) or queries (False);
                selects the document or query tokenizer.
            include_extension (bool): Turn on to count DOC / QUERY marker tokens, and [MASK] token in query mode.

        Returns:
            int: Sum of number of tokens in the texts.
        """
        return self.model.token_count(
            texts,
            batch_size=batch_size,
            is_doc=is_doc,
            include_extension=include_extension,
            **kwargs,
        )
|