fastembed_bio-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
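
The remainder of this diff shows fastembed/sparse/bm25.py, the BM25 implementation added in this wheel. For orientation, here is a minimal usage sketch based on the signatures in that file (it assumes the wheel is installed and that the Qdrant/bm25 stopword files can be fetched from Hugging Face; the document and query strings are illustrative, not part of the package):

from fastembed.sparse.bm25 import Bm25

# "Qdrant/bm25" is the single entry in supported_bm25_models below.
model = Bm25("Qdrant/bm25", language="english")

# Document side: term-frequency-weighted sparse vectors (IDF is applied by Qdrant).
doc_embeddings = list(model.embed(["The quick brown fox jumps over the lazy dog."]))

# Query side: deduplicated hashed tokens, each with weight 1.0.
query_embedding = next(model.query_embed("quick fox"))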
fastembed/sparse/bm25.py
ADDED
@@ -0,0 +1,359 @@
import os
from collections import defaultdict
from multiprocessing import get_all_start_methods
from pathlib import Path
from typing import Any, Iterable, Type

import mmh3
import numpy as np
from py_rust_stemmers import SnowballStemmer

from fastembed.common.utils import (
    define_cache_dir,
    iter_batch,
    get_all_punctuation,
    remove_non_alphanumeric,
)
from fastembed.parallel_processor import ParallelWorkerPool, Worker
from fastembed.sparse.sparse_embedding_base import (
    SparseEmbedding,
    SparseTextEmbeddingBase,
)
from fastembed.sparse.utils.tokenizer import SimpleTokenizer
from fastembed.common.model_description import SparseModelDescription, ModelSource


supported_languages = [
    "arabic",
    "danish",
    "dutch",
    "english",
    "finnish",
    "french",
    "german",
    "greek",
    "hungarian",
    "italian",
    "norwegian",
    "portuguese",
    "romanian",
    "russian",
    "spanish",
    "swedish",
    "tamil",
    "turkish",
]

supported_bm25_models: list[SparseModelDescription] = [
    SparseModelDescription(
        model="Qdrant/bm25",
        vocab_size=0,
        description="BM25 as sparse embeddings meant to be used with Qdrant",
        license="apache-2.0",
        size_in_GB=0.01,
        sources=ModelSource(hf="Qdrant/bm25"),
        additional_files=[f"{lang}.txt" for lang in supported_languages],
        requires_idf=True,
        model_file="mock.file",
    ),
]


class Bm25(SparseTextEmbeddingBase):
    """Implements traditional BM25 in the form of sparse embeddings.

    Uses the count of each token in the document to evaluate its importance.

    WARNING: This model is expected to be used with `modifier="idf"` in the sparse vector index of Qdrant.

    BM25 formula:

    score(q, d) = SUM[ IDF(q_i) * (f(q_i, d) * (k + 1)) / (f(q_i, d) + k * (1 - b + b * (|d| / avg_len))) ],

    where IDF is the inverse document frequency (computed on Qdrant's side),
    f(q_i, d) is the term frequency of the token q_i in the document d,
    and k, b, avg_len are hyperparameters described below.

    Args:
        model_name (str): The name of the model to use.
        cache_dir (str, optional): The path to the cache directory.
            Can be set using the `FASTEMBED_CACHE_PATH` env variable.
            Defaults to `fastembed_cache` in the system's temp directory.
        k (float, optional): The k parameter in the BM25 formula. Controls term-frequency
            saturation, i.e. how quickly additional occurrences of a term stop increasing
            the score. Defaults to 1.2.
        b (float, optional): The b parameter in the BM25 formula. Defines the importance
            of the document length. Defaults to 0.75.
        avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0.
        language (str): Specifies the language for the stemmer.
        disable_stemmer (bool): Disable the stemmer.

    Raises:
        ValueError: If the model_name is not in the format <org>/<model>, e.g. BAAI/bge-base-en.
    """

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        k: float = 1.2,
        b: float = 0.75,
        avg_len: float = 256.0,
        language: str = "english",
        token_max_length: int = 40,
        disable_stemmer: bool = False,
        specific_model_path: str | None = None,
        **kwargs: Any,
    ):
        super().__init__(model_name, cache_dir, **kwargs)

        if language not in supported_languages:
            raise ValueError(f"{language} language is not supported")
        self.language = language

        self.k = k
        self.b = b
        self.avg_len = avg_len

        model_description = self._get_model_description(model_name)
        self.cache_dir = str(define_cache_dir(cache_dir))

        self._specific_model_path = specific_model_path
        self._model_dir = self.download_model(
            model_description,
            self.cache_dir,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
        )

        self.token_max_length = token_max_length
        self.punctuation = set(get_all_punctuation())
        self.disable_stemmer = disable_stemmer

        if disable_stemmer:
            self.stopwords: set[str] = set()
            self.stemmer = None
        else:
            self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
            self.stemmer = SnowballStemmer(language)

        self.tokenizer = SimpleTokenizer

    @classmethod
    def _list_supported_models(cls) -> list[SparseModelDescription]:
        """Lists the supported models.

        Returns:
            list[SparseModelDescription]: A list of SparseModelDescription objects containing the model information.
        """
        return supported_bm25_models

    @classmethod
    def _load_stopwords(cls, model_dir: Path, language: str) -> list[str]:
        stopwords_path = model_dir / f"{language}.txt"
        if not stopwords_path.exists():
            return []

        with open(stopwords_path, "r") as f:
            return f.read().splitlines()

    def _embed_documents(
        self,
        model_name: str,
        cache_dir: str,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        local_files_only: bool = False,
        specific_model_path: str | None = None,
    ) -> Iterable[SparseEmbedding]:
        is_small = False

        if isinstance(documents, str):
            documents = [documents]
            is_small = True

        if isinstance(documents, list) and len(documents) < batch_size:
            is_small = True

        if parallel is None or is_small:
            for batch in iter_batch(documents, batch_size):
                yield from self.raw_embed(batch)
        else:
            if parallel == 0:
                parallel = os.cpu_count()

            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
            params = {
                "model_name": model_name,
                "cache_dir": cache_dir,
                "k": self.k,
                "b": self.b,
                "avg_len": self.avg_len,
                "language": self.language,
                "token_max_length": self.token_max_length,
                "disable_stemmer": self.disable_stemmer,
                "local_files_only": local_files_only,
                "specific_model_path": specific_model_path,
            }
            pool = ParallelWorkerPool(
                num_workers=parallel or 1,
                worker=self._get_worker_class(),
                start_method=start_method,
            )
            for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
                for record in batch:
                    yield record  # type: ignore

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[SparseEmbedding]:
        """
        Encode a list of documents into a list of sparse embeddings.
        Embeddings are computed from token counts, so documents of any length can be handled.

        Args:
            documents: Iterator of documents or a single document to embed.
            batch_size: Batch size for encoding -- higher values use more memory, but are faster.
            parallel:
                If > 1, data-parallel encoding will be used; recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing; encode sequentially in the current process.

        Returns:
            List of embeddings, one per document.
        """
        yield from self._embed_documents(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            documents=documents,
            batch_size=batch_size,
            parallel=parallel,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
        )

    def _stem(self, tokens: list[str]) -> list[str]:
        stemmed_tokens: list[str] = []
        for token in tokens:
            lower_token = token.lower()

            if token in self.punctuation:
                continue

            if lower_token in self.stopwords:
                continue

            if len(token) > self.token_max_length:
                continue

            stemmed_token = self.stemmer.stem_word(lower_token) if self.stemmer else lower_token

            if stemmed_token:
                stemmed_tokens.append(stemmed_token)
        return stemmed_tokens

    def raw_embed(
        self,
        documents: list[str],
    ) -> list[SparseEmbedding]:
        embeddings: list[SparseEmbedding] = []
        for document in documents:
            document = remove_non_alphanumeric(document)
            tokens = self.tokenizer.tokenize(document)
            stemmed_tokens = self._stem(tokens)
            token_id2value = self._term_frequency(stemmed_tokens)
            embeddings.append(SparseEmbedding.from_dict(token_id2value))
        return embeddings

    def token_count(self, texts: str | Iterable[str], **kwargs: Any) -> int:
        token_num = 0
        texts = [texts] if isinstance(texts, str) else texts
        for text in texts:
            document = remove_non_alphanumeric(text)
            tokens = self.tokenizer.tokenize(document)
            token_num += len(tokens)
        return token_num

    def _term_frequency(self, tokens: list[str]) -> dict[int, float]:
        """Calculate the term frequency part of the BM25 formula:

        (
            f(q_i, d) * (k + 1)
        ) / (
            f(q_i, d) + k * (1 - b + b * (|d| / avg_len))
        )

        Args:
            tokens (list[str]): The list of tokens in the document.

        Returns:
            dict[int, float]: The token_id to term frequency mapping.
        """
        tf_map: dict[int, float] = {}
        counter: defaultdict[str, int] = defaultdict(int)
        for stemmed_token in tokens:
            counter[stemmed_token] += 1

        doc_len = len(tokens)
        for stemmed_token in counter:
            token_id = self.compute_token_id(stemmed_token)
            num_occurrences = counter[stemmed_token]
            tf_map[token_id] = num_occurrences * (self.k + 1)
            tf_map[token_id] /= num_occurrences + self.k * (
                1 - self.b + self.b * doc_len / self.avg_len
            )
        return tf_map
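
    # Worked example (added for illustration) with the defaults k=1.2, b=0.75,
    # avg_len=256.0: a token occurring twice in a 10-token document scores
    # 2 * 2.2 / (2 + 1.2 * (1 - 0.75 + 0.75 * 10 / 256.0)) ≈ 1.88, while a single
    # occurrence gives ≈ 1.65. No count can push the value past k + 1 = 2.2,
    # which is the saturation behaviour the k parameter controls.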

    @classmethod
    def compute_token_id(cls, token: str) -> int:
        return abs(mmh3.hash(token))

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]:
        """To emulate BM25 behaviour, we don't need weights in the query:
        it's enough to hash the tokens and assign them a weight of 1.0.
        """
        if isinstance(query, str):
            query = [query]

        for text in query:
            text = remove_non_alphanumeric(text)
            tokens = self.tokenizer.tokenize(text)
            stemmed_tokens = self._stem(tokens)
            token_ids = np.array(
                list(set(self.compute_token_id(token) for token in stemmed_tokens)),
                dtype=np.int32,
            )
            values = np.ones_like(token_ids)
            yield SparseEmbedding(indices=token_ids, values=values)
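
    # Note (added for illustration): query vectors are binary presence
    # indicators. For example, query_embed("fox fox jumps") stems and
    # deduplicates the tokens to {"fox", "jump"} and assigns each a value
    # of 1.0; term frequency only matters on the document side, via
    # _term_frequency above.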

    @classmethod
    def _get_worker_class(cls) -> Type["Bm25Worker"]:
        return Bm25Worker


class Bm25Worker(Worker):
    def __init__(
        self,
        model_name: str,
        cache_dir: str,
        **kwargs: Any,
    ):
        self.model = self.init_embedding(model_name, cache_dir, **kwargs)

    @classmethod
    def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "Bm25Worker":
        return cls(model_name=model_name, cache_dir=cache_dir, **kwargs)

    def process(
        self, items: Iterable[tuple[int, Any]]
    ) -> Iterable[tuple[int, list[SparseEmbedding]]]:
        for idx, batch in items:
            embeddings = self.model.raw_embed(batch)
            yield idx, embeddings

    @staticmethod
    def init_embedding(model_name: str, cache_dir: str, **kwargs: Any) -> Bm25:
        return Bm25(model_name=model_name, cache_dir=cache_dir, **kwargs)
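
The class docstring above warns that these embeddings are meant for a Qdrant sparse vector index with `modifier="idf"`: Qdrant computes the IDF term of the formula server-side, which is what requires_idf=True in the model description signals. A companion collection setup might look like the following sketch; the qdrant-client calls are an assumption based on the public qdrant-client API, not part of this package:

from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")  # hypothetical local instance

client.create_collection(
    collection_name="bm25-demo",  # hypothetical collection name
    vectors_config={},
    sparse_vectors_config={"bm25": models.SparseVectorParams(modifier=models.Modifier.IDF)},
)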