fastembed-bio 0.1.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
+++ fastembed/sparse/minicoil.py
@@ -0,0 +1,372 @@
from pathlib import Path

from typing import Any, Sequence, Iterable, Type

import numpy as np
from numpy.typing import NDArray
from py_rust_stemmers import SnowballStemmer
from tokenizers import Tokenizer

from fastembed.common.model_description import SparseModelDescription, ModelSource
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common import OnnxProvider
from fastembed.common.types import Device
from fastembed.common.utils import define_cache_dir
from fastembed.sparse.sparse_embedding_base import (
    SparseEmbedding,
    SparseTextEmbeddingBase,
)
from fastembed.sparse.utils.minicoil_encoder import Encoder
from fastembed.sparse.utils.sparse_vectors_converter import SparseVectorConverter, WordEmbedding
from fastembed.sparse.utils.vocab_resolver import VocabResolver, VocabTokenizer
from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker


MINICOIL_MODEL_FILE = "minicoil.triplet.model.npy"
MINICOIL_VOCAB_FILE = "minicoil.triplet.model.vocab"
STOPWORDS_FILE = "stopwords.txt"


supported_minicoil_models: list[SparseModelDescription] = [
    SparseModelDescription(
        model="Qdrant/minicoil-v1",
        vocab_size=19125,
        description="Sparse embedding model, that resolves semantic meaning of the words, "
        "while keeping exact keyword match behavior. "
        "Based on jinaai/jina-embeddings-v2-small-en-tokens",
        license="apache-2.0",
        size_in_GB=0.09,
        sources=ModelSource(hf="Qdrant/minicoil-v1"),
        model_file="onnx/model.onnx",
        additional_files=[
            STOPWORDS_FILE,
            MINICOIL_MODEL_FILE,
            MINICOIL_VOCAB_FILE,
        ],
        requires_idf=True,
    ),
]

_MODEL_TO_LANGUAGE = {
    "Qdrant/minicoil-v1": "english",
}
MODEL_TO_LANGUAGE = {
    model_name.lower(): language for model_name, language in _MODEL_TO_LANGUAGE.items()
}


def get_language_by_model_name(model_name: str) -> str:
    return MODEL_TO_LANGUAGE[model_name.lower()]


class MiniCOIL(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]):
    """
    MiniCOIL is a sparse embedding model, that resolves semantic meaning of the words,
    while keeping exact keyword match behavior.

    Each vocabulary token is converted into 4d component of a sparse vector, which is then weighted by the token frequency in the corpus.
    If the token is not found in the corpus, it is treated exactly like in BM25.

    The model is based on `jinaai/jina-embeddings-v2-small-en-tokens`
    """

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        k: float = 1.2,
        b: float = 0.75,
        avg_len: float = 150.0,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        lazy_load: bool = False,
        device_id: int | None = None,
        specific_model_path: str | None = None,
        **kwargs: Any,
    ):
        """
        Args:
            model_name (str): The name of the model to use.
            cache_dir (str, optional): The path to the cache directory.
                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                Defaults to `fastembed_cache` in the system's temp directory.
            threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
            providers (Optional[Sequence[OnnxProvider]], optional): The providers to use for onnxruntime.
            k (float, optional): The k parameter in the BM25 formula. Defines the saturation of the term frequency.
                I.e. defines how fast the moment when additional terms stop to increase the score. Defaults to 1.2.
            b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
                Defaults to 0.75.
            avg_len (float, optional): The average length of the documents in the corpus. Defaults to 150.0.
            cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`
                Defaults to Device.AUTO.
            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
                workers. Should be used with `cuda` equals to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
                with `providers`. Defaults to None.
            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
                Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
            specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else

        Raises:
            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
        """

        super().__init__(model_name, cache_dir, threads, **kwargs)
        self.providers = providers
        self.lazy_load = lazy_load
        self.device_ids = device_ids
        self.cuda = cuda
        self.device_id = device_id
        self._extra_session_options = self._select_exposed_session_options(kwargs)

        self.k = k
        self.b = b
        self.avg_len = avg_len

        # Initialize class attributes
        self.tokenizer: Tokenizer | None = None
        self.invert_vocab: dict[int, str] = {}
        self.special_tokens: set[str] = set()
        self.special_tokens_ids: set[int] = set()
        self.stopwords: set[str] = set()
        self.vocab_resolver: VocabResolver | None = None
        self.encoder: Encoder | None = None
        self.output_dim: int | None = None
        self.sparse_vector_converter: SparseVectorConverter | None = None

        self.model_description = self._get_model_description(model_name)
        self.cache_dir = str(define_cache_dir(cache_dir))
        self._specific_model_path = specific_model_path
        self._model_dir = self.download_model(
            self.model_description,
            self.cache_dir,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
        )

        if not self.lazy_load:
            self.load_onnx_model()

    def load_onnx_model(self) -> None:
        self._load_onnx_model(
            model_dir=self._model_dir,
            model_file=self.model_description.model_file,
            threads=self.threads,
            providers=self.providers,
            cuda=self.cuda,
            device_id=self.device_id,
            extra_session_options=self._extra_session_options,
        )

        assert self.tokenizer is not None

        for token, idx in self.tokenizer.get_vocab().items():  # type: ignore[union-attr]
            self.invert_vocab[idx] = token
        self.special_tokens = set(self.special_token_to_id.keys())
        self.special_tokens_ids = set(self.special_token_to_id.values())
        self.stopwords = set(self._load_stopwords(self._model_dir))

        stemmer = SnowballStemmer(get_language_by_model_name(self.model_name))

        self.vocab_resolver = VocabResolver(
            tokenizer=VocabTokenizer(self.tokenizer),
            stopwords=self.stopwords,
            stemmer=stemmer,
        )
        self.vocab_resolver.load_json_vocab(str(self._model_dir / MINICOIL_VOCAB_FILE))

        weights = np.load(str(self._model_dir / MINICOIL_MODEL_FILE), mmap_mode="r")
        self.encoder = Encoder(weights)
        self.output_dim = self.encoder.output_dim

        self.sparse_vector_converter = SparseVectorConverter(
            stopwords=self.stopwords,
            stemmer=stemmer,
            k=self.k,
            b=self.b,
            avg_len=self.avg_len,
        )

    def token_count(
        self, texts: str | Iterable[str], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        return self._token_count(texts, batch_size=batch_size, **kwargs)

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[SparseEmbedding]:
        """
        Encode a list of documents into list of embeddings.
        We use mean pooling with attention so that the model can handle variable-length inputs.

        Args:
            documents: Iterator of documents or single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per document
        """
        yield from self._embed_documents(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            documents=documents,
            batch_size=batch_size,
            parallel=parallel,
            providers=self.providers,
            cuda=self.cuda,
            device_ids=self.device_ids,
            k=self.k,
            b=self.b,
            avg_len=self.avg_len,
            is_query=False,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
            extra_session_options=self._extra_session_options,
            **kwargs,
        )

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]:
        """
        Encode a list of queries into list of embeddings.
        """
        yield from self._embed_documents(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            documents=query,
            providers=self.providers,
            cuda=self.cuda,
            device_ids=self.device_ids,
            k=self.k,
            b=self.b,
            avg_len=self.avg_len,
            is_query=True,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
            **kwargs,
        )

    @classmethod
    def _load_stopwords(cls, model_dir: Path) -> list[str]:
        stopwords_path = model_dir / STOPWORDS_FILE
        if not stopwords_path.exists():
            return []

        with open(stopwords_path, "r") as f:
            return f.read().splitlines()

    @classmethod
    def _list_supported_models(cls) -> list[SparseModelDescription]:
        """Lists the supported models.

        Returns:
            list[SparseModelDescription]: A list of SparseModelDescription objects containing the model information.
        """
        return supported_minicoil_models

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, is_query: bool = False, **kwargs: Any
    ) -> Iterable[SparseEmbedding]:
        if output.input_ids is None:
            raise ValueError("input_ids must be provided for document post-processing")

        assert self.vocab_resolver is not None
        assert self.encoder is not None
        assert self.sparse_vector_converter is not None

        # Size: (batch_size, sequence_length, hidden_size)
        embeddings = output.model_output
        # Size: (batch_size, sequence_length)
        assert output.attention_mask is not None
        masks = output.attention_mask

        vocab_size = self.vocab_resolver.vocab_size()
        embedding_size = self.encoder.output_dim

        # For each document we only select those embeddings that are not masked out

        for i in range(embeddings.shape[0]):
            # Size: (sequence_length, hidden_size)
            token_embeddings = embeddings[i, masks[i] == 1]

            # Size: (sequence_length)
            token_ids: NDArray[np.int64] = output.input_ids[i, masks[i] == 1]

            word_ids_array, counts, oov, forms = self.vocab_resolver.resolve_tokens(token_ids)

            # Size: (1, words)
            word_ids_array_expanded: NDArray[np.int64] = np.expand_dims(word_ids_array, axis=0)

            # Size: (1, words, embedding_size)
            token_embeddings_array: NDArray[np.float32] = np.expand_dims(token_embeddings, axis=0)

            assert word_ids_array_expanded.shape[1] == token_embeddings_array.shape[1]

            # Size of word_ids_mapping: (unique_words, 2) - [vocab_id, batch_id]
            # Size of embeddings: (unique_words, embedding_size)
            ids_mapping, minicoil_embeddings = self.encoder.forward(
                word_ids_array_expanded, token_embeddings_array
            )

            # Size of counts: (unique_words)
            words_ids: list[int] = ids_mapping[:, 0].tolist()  # type: ignore[assignment]

            sentence_result: dict[str, WordEmbedding] = {}

            words = [self.vocab_resolver.lookup_word(word_id) for word_id in words_ids]

            for word, word_id, emb in zip(words, words_ids, minicoil_embeddings.tolist()):  # type: ignore[arg-type]
                if word_id == 0:
                    continue

                sentence_result[word] = WordEmbedding(
                    word=word,
                    forms=forms[word],
                    count=int(counts[word_id]),
                    word_id=int(word_id),
                    embedding=emb,  # type: ignore[arg-type]
                )

            for oov_word, count in oov.items():
                # {
                #     "word": oov_word,
                #     "forms": [oov_word],
                #     "count": int(count),
                #     "word_id": -1,
                #     "embedding": [1]
                # }
                sentence_result[oov_word] = WordEmbedding(
                    word=oov_word, forms=[oov_word], count=int(count), word_id=-1, embedding=[1]
                )

            if not is_query:
                yield self.sparse_vector_converter.embedding_to_vector(
                    sentence_result, vocab_size=vocab_size, embedding_size=embedding_size
                )
            else:
                yield self.sparse_vector_converter.embedding_to_vector_query(
                    sentence_result, vocab_size=vocab_size, embedding_size=embedding_size
                )

    @classmethod
    def _get_worker_class(cls) -> Type["MiniCoilTextEmbeddingWorker"]:
        return MiniCoilTextEmbeddingWorker


class MiniCoilTextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
    def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> MiniCOIL:
        return MiniCOIL(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
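For orientation, a minimal usage sketch of the MiniCOIL class added above — not part of the diff. It assumes the wheel is installed and the Qdrant/minicoil-v1 weights are reachable (downloaded from Hugging Face on first use, or supplied via `specific_model_path`); the calls themselves (`embed`, `query_embed`) and the `SparseEmbedding` fields are taken directly from this file.

    from fastembed.sparse.minicoil import MiniCOIL

    model = MiniCOIL("Qdrant/minicoil-v1")  # fetches ~0.09 GB of model files on first use

    # Documents and queries are converted differently (the is_query flag above
    # switches between the two SparseVectorConverter paths), hence two entry points.
    doc_embeddings = list(model.embed(["Vector search with exact keyword match"]))
    query_embedding = next(model.query_embed("keyword match"))

    sparse = doc_embeddings[0]
    print(sparse.indices[:5], sparse.values[:5])  # non-zero dimensions of the sparse vector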
+++ fastembed/sparse/sparse_embedding_base.py
@@ -0,0 +1,90 @@
from dataclasses import dataclass
from typing import Iterable, Any

import numpy as np
from numpy.typing import NDArray

from fastembed.common.model_description import SparseModelDescription
from fastembed.common.types import NumpyArray
from fastembed.common.model_management import ModelManagement


@dataclass
class SparseEmbedding:
    values: NumpyArray
    indices: NDArray[np.int64] | NDArray[np.int32]

    def as_object(self) -> dict[str, NumpyArray]:
        return {
            "values": self.values,
            "indices": self.indices,
        }

    def as_dict(self) -> dict[int, float]:
        return {int(i): float(v) for i, v in zip(self.indices, self.values)}  # type: ignore

    @classmethod
    def from_dict(cls, data: dict[int, float]) -> "SparseEmbedding":
        if len(data) == 0:
            return cls(values=np.array([]), indices=np.array([]))
        indices, values = zip(*data.items())
        return cls(values=np.array(values), indices=np.array(indices))


class SparseTextEmbeddingBase(ModelManagement[SparseModelDescription]):
    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        **kwargs: Any,
    ):
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.threads = threads
        self._local_files_only = kwargs.pop("local_files_only", False)

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[SparseEmbedding]:
        raise NotImplementedError()

    def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]:
        """
        Embeds a list of text passages into a list of embeddings.

        Args:
            texts (Iterable[str]): The list of texts to embed.
            **kwargs: Additional keyword argument to pass to the embed method.

        Yields:
            Iterable[SparseEmbedding]: The sparse embeddings.
        """

        # This is model-specific, so that different models can have specialized implementations
        yield from self.embed(texts, **kwargs)

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]:
        """
        Embeds queries

        Args:
            query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.

        Returns:
            Iterable[SparseEmbedding]: The sparse embeddings.
        """

        # This is model-specific, so that different models can have specialized implementations
        if isinstance(query, str):
            yield from self.embed([query], **kwargs)
        else:
            yield from self.embed(query, **kwargs)

    def token_count(self, texts: str | Iterable[str], **kwargs: Any) -> int:
        """Returns the number of tokens in the texts."""
        raise NotImplementedError("Subclasses must implement this method")
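A quick round-trip through the SparseEmbedding container defined above — a sketch, not part of the diff, but it uses only code shown here and runs as-is once the package is installed:

    from fastembed.sparse.sparse_embedding_base import SparseEmbedding

    emb = SparseEmbedding.from_dict({17: 0.8, 4021: 0.35})
    print(emb.indices)  # [  17 4021]
    print(emb.values)   # [0.8  0.35]
    assert emb.as_dict() == {17: 0.8, 4021: 0.35}

    # Empty input is handled explicitly: both arrays come back empty.
    assert SparseEmbedding.from_dict({}).as_dict() == {}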
+++ fastembed/sparse/sparse_text_embedding.py
@@ -0,0 +1,143 @@
from typing import Any, Iterable, Sequence, Type
from dataclasses import asdict

from fastembed.common import OnnxProvider
from fastembed.common.types import Device
from fastembed.sparse.bm25 import Bm25
from fastembed.sparse.bm42 import Bm42
from fastembed.sparse.minicoil import MiniCOIL
from fastembed.sparse.sparse_embedding_base import (
    SparseEmbedding,
    SparseTextEmbeddingBase,
)
from fastembed.sparse.splade_pp import SpladePP
import warnings
from fastembed.common.model_description import SparseModelDescription


class SparseTextEmbedding(SparseTextEmbeddingBase):
    EMBEDDINGS_REGISTRY: list[Type[SparseTextEmbeddingBase]] = [SpladePP, Bm42, Bm25, MiniCOIL]

    @classmethod
    def list_supported_models(cls) -> list[dict[str, Any]]:
        """
        Lists the supported models.

        Returns:
            list[dict[str, Any]]: A list of dictionaries containing the model information.

        Example:
            ```
            [
                {
                    "model": "prithvida/SPLADE_PP_en_v1",
                    "vocab_size": 30522,
                    "description": "Independent Implementation of SPLADE++ Model for English",
                    "license": "apache-2.0",
                    "size_in_GB": 0.532,
                    "sources": {
                        "hf": "qdrant/SPLADE_PP_en_v1",
                    },
                }
            ]
            ```
        """
        return [asdict(model) for model in cls._list_supported_models()]

    @classmethod
    def _list_supported_models(cls) -> list[SparseModelDescription]:
        result: list[SparseModelDescription] = []
        for embedding in cls.EMBEDDINGS_REGISTRY:
            result.extend(embedding._list_supported_models())
        return result

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        lazy_load: bool = False,
        **kwargs: Any,
    ):
        super().__init__(model_name, cache_dir, threads, **kwargs)
        if model_name.lower() == "prithvida/Splade_PP_en_v1".lower():
            warnings.warn(
                "The right spelling is prithivida/Splade_PP_en_v1. "
                "Support of this name will be removed soon, please fix the model_name",
                DeprecationWarning,
                stacklevel=2,
            )
            model_name = "prithivida/Splade_PP_en_v1"

        for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
            supported_models = EMBEDDING_MODEL_TYPE._list_supported_models()
            if any(model_name.lower() == model.model.lower() for model in supported_models):
                self.model = EMBEDDING_MODEL_TYPE(
                    model_name,
                    cache_dir,
                    threads=threads,
                    providers=providers,
                    cuda=cuda,
                    device_ids=device_ids,
                    lazy_load=lazy_load,
                    **kwargs,
                )
                return

        raise ValueError(
            f"Model {model_name} is not supported in SparseTextEmbedding."
            "Please check the supported models using `SparseTextEmbedding.list_supported_models()`"
        )

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[SparseEmbedding]:
        """
        Encode a list of documents into list of embeddings.
        We use mean pooling with attention so that the model can handle variable-length inputs.

        Args:
            documents: Iterator of documents or single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per document
        """
        yield from self.model.embed(documents, batch_size, parallel, **kwargs)

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]:
        """
        Embeds queries

        Args:
            query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.

        Returns:
            Iterable[SparseEmbedding]: The sparse embeddings.
        """
        yield from self.model.query_embed(query, **kwargs)

    def token_count(
        self, texts: str | Iterable[str], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        """Returns the number of tokens in the texts.

        Args:
            texts (str | Iterable[str]): The list of texts to embed.
            batch_size (int): Batch size for encoding

        Returns:
            int: Sum of number of tokens in the texts.
        """
        return self.model.token_count(texts, batch_size=batch_size, **kwargs)