fastembed_bio-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/sparse/bm42.py
@@ -0,0 +1,369 @@
+import math
+import string
+from pathlib import Path
+from typing import Any, Iterable, Sequence, Type
+
+import mmh3
+import numpy as np
+from py_rust_stemmers import SnowballStemmer
+
+from fastembed.common import OnnxProvider
+from fastembed.common.onnx_model import OnnxOutputContext
+from fastembed.common.types import Device
+from fastembed.common.utils import define_cache_dir
+from fastembed.sparse.sparse_embedding_base import (
+    SparseEmbedding,
+    SparseTextEmbeddingBase,
+)
+from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker
+from fastembed.common.model_description import SparseModelDescription, ModelSource
+
+supported_bm42_models: list[SparseModelDescription] = [
+    SparseModelDescription(
+        model="Qdrant/bm42-all-minilm-l6-v2-attentions",
+        vocab_size=30522,
+        description="Light sparse embedding model, which assigns an importance score to each token in the text",
+        license="apache-2.0",
+        size_in_GB=0.09,
+        sources=ModelSource(hf="Qdrant/all_miniLM_L6_v2_with_attentions"),
+        model_file="model.onnx",
+        additional_files=["stopwords.txt"],
+        requires_idf=True,
+    ),
+]
+
+
+_MODEL_TO_LANGUAGE = {
+    "Qdrant/bm42-all-minilm-l6-v2-attentions": "english",
+}
+MODEL_TO_LANGUAGE = {
+    model_name.lower(): language for model_name, language in _MODEL_TO_LANGUAGE.items()
+}
+
+
+def get_language_by_model_name(model_name: str) -> str:
+    return MODEL_TO_LANGUAGE[model_name.lower()]
+
+
+class Bm42(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]):
+    """
+    Bm42 is an extension of BM25, which tries to better evaluate the importance of tokens in documents
+    by extracting attention weights from the transformer model.
+
+    Traditional BM25 uses a count of tokens in the document to evaluate the importance of the token,
+    but this approach doesn't work well with short documents or chunks of text, as almost all tokens
+    there are unique.
+
+    BM42 addresses this issue by replacing the token count with the attention weights from the transformer model.
+    This allows sparse embeddings to work well with short documents, handle rare tokens and leverage traditional NLP
+    techniques like stemming and stopwords.
+
+    WARNING: This model is expected to be used with `modifier="idf"` in the sparse vector index of Qdrant.
+    """
+
+    ONNX_OUTPUT_NAMES = ["attention_6"]
+
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: str | None = None,
+        threads: int | None = None,
+        providers: Sequence[OnnxProvider] | None = None,
+        alpha: float = 0.5,
+        cuda: bool | Device = Device.AUTO,
+        device_ids: list[int] | None = None,
+        lazy_load: bool = False,
+        device_id: int | None = None,
+        specific_model_path: str | None = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            model_name (str): The name of the model to use.
+            cache_dir (str, optional): The path to the cache directory.
+                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+                Defaults to `fastembed_cache` in the system's temp directory.
+            threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
+            providers (Optional[Sequence[OnnxProvider]], optional): The providers to use for onnxruntime.
+            alpha (float, optional): Parameter that defines the importance of the token weight in the document
+                versus the importance of the token frequency in the corpus. Defaults to 0.5, based on empirical testing.
+                It is recommended to only change this parameter based on training data for a specific dataset.
+            cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
+                Defaults to Device.AUTO.
+            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
+                workers. Should be used with `cuda` equal to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
+                with `providers`. Defaults to None.
+            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                Should be set to True when using multiple GPUs and parallel encoding. Defaults to False.
+            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
+            specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else.
+
+        Raises:
+            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+        """
+
+        super().__init__(model_name, cache_dir, threads, **kwargs)
+        self.providers = providers
+        self.lazy_load = lazy_load
+        self._extra_session_options = self._select_exposed_session_options(kwargs)
+
+        # List of device ids that can be used for data parallel processing in workers
+        self.device_ids = device_ids
+        self.cuda = cuda
+
+        # This device_id will be used if we need to load the model in the current process
+        self.device_id: int | None = None
+        if device_id is not None:
+            self.device_id = device_id
+        elif self.device_ids is not None:
+            self.device_id = self.device_ids[0]
+
+        self.model_description = self._get_model_description(model_name)
+        self.cache_dir = str(define_cache_dir(cache_dir))
+
+        self._specific_model_path = specific_model_path
+        self._model_dir = self.download_model(
+            self.model_description,
+            self.cache_dir,
+            local_files_only=self._local_files_only,
+            specific_model_path=self._specific_model_path,
+        )
+
+        self.invert_vocab: dict[int, str] = {}
+
+        self.special_tokens: set[str] = set()
+        self.special_tokens_ids: set[int] = set()
+        self.punctuation = set(string.punctuation)
+        self.stopwords = set(self._load_stopwords(self._model_dir))
+        self.stemmer = SnowballStemmer(get_language_by_model_name(self.model_name))
+        self.alpha = alpha
+
+        if not self.lazy_load:
+            self.load_onnx_model()
+
+    def load_onnx_model(self) -> None:
+        self._load_onnx_model(
+            model_dir=self._model_dir,
+            model_file=self.model_description.model_file,
+            threads=self.threads,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_id=self.device_id,
+            extra_session_options=self._extra_session_options,
+        )
+
+        for token, idx in self.tokenizer.get_vocab().items():  # type: ignore[union-attr]
+            self.invert_vocab[idx] = token
+        self.special_tokens = set(self.special_token_to_id.keys())
+        self.special_tokens_ids = set(self.special_token_to_id.values())
+        self.stopwords = set(self._load_stopwords(self._model_dir))
+
+    def _filter_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
+        result: list[tuple[str, Any]] = []
+        for token, value in tokens:
+            if token in self.stopwords or token in self.punctuation:
+                continue
+            result.append((token, value))
+        return result
+
+    def _stem_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
+        result: list[tuple[str, Any]] = []
+        for token, value in tokens:
+            processed_token = self.stemmer.stem_word(token)
+            result.append((processed_token, value))
+        return result
+
+    @classmethod
+    def _aggregate_weights(
+        cls, tokens: list[tuple[str, list[int]]], weights: list[float]
+    ) -> list[tuple[str, float]]:
+        result: list[tuple[str, float]] = []
+        for token, idxs in tokens:
+            sum_weight = sum(weights[idx] for idx in idxs)
+            result.append((token, sum_weight))
+        return result
+
+    def _reconstruct_bpe(
+        self, bpe_tokens: Iterable[tuple[int, str]]
+    ) -> list[tuple[str, list[int]]]:
+        result: list[tuple[str, list[int]]] = []
+        acc: str = ""
+        acc_idx: list[int] = []
+
+        continuing_subword_prefix = self.tokenizer.model.continuing_subword_prefix  # type: ignore[union-attr]
+        continuing_subword_prefix_len = len(continuing_subword_prefix)
+
+        for idx, token in bpe_tokens:
+            if token in self.special_tokens:
+                continue
+
+            if token.startswith(continuing_subword_prefix):
+                acc += token[continuing_subword_prefix_len:]
+                acc_idx.append(idx)
+            else:
+                if acc:
+                    result.append((acc, acc_idx))
+                    acc_idx = []
+                acc = token
+                acc_idx.append(idx)
+
+        if acc:
+            result.append((acc, acc_idx))
+
+        return result
+
+    def _rescore_vector(self, vector: dict[str, float]) -> dict[int, float]:
+        """
+        Orders all tokens in the vector by their importance and generates a new score based on the importance order.
+        So that the scoring doesn't depend on absolute values assigned by the model, but on the relative importance.
+        """
+
+        new_vector: dict[int, float] = {}
+
+        for token, value in vector.items():
+            token_id = abs(mmh3.hash(token))
+            # Examples:
+            # Num 0: Log(1/1 + 1) = 0.6931471805599453
+            # Num 1: Log(1/2 + 1) = 0.4054651081081644
+            # Num 2: Log(1/3 + 1) = 0.28768207245178085
+            new_vector[token_id] = math.log(1.0 + value) ** self.alpha  # value
+
+        return new_vector
+
+    def _post_process_onnx_output(
+        self, output: OnnxOutputContext, **kwargs: Any
+    ) -> Iterable[SparseEmbedding]:
+        if output.input_ids is None:
+            raise ValueError("input_ids must be provided for document post-processing")
+
+        token_ids_batch = output.input_ids.astype(int)
+
+        # attention_value shape: (batch_size, num_heads, num_tokens, num_tokens)
+        pooled_attention = np.mean(output.model_output[:, :, 0], axis=1) * output.attention_mask
+
+        for document_token_ids, attention_value in zip(token_ids_batch, pooled_attention):
+            document_tokens_with_ids = (
+                (idx, self.invert_vocab[token_id])
+                for idx, token_id in enumerate(document_token_ids)
+            )
+
+            reconstructed = self._reconstruct_bpe(document_tokens_with_ids)
+
+            filtered = self._filter_pair_tokens(reconstructed)
+
+            stemmed = self._stem_pair_tokens(filtered)
+
+            weighted = self._aggregate_weights(stemmed, attention_value)
+
+            max_token_weight: dict[str, float] = {}
+
+            for token, weight in weighted:
+                max_token_weight[token] = max(max_token_weight.get(token, 0), weight)
+
+            rescored = self._rescore_vector(max_token_weight)
+
+            yield SparseEmbedding.from_dict(rescored)
+
+    @classmethod
+    def _list_supported_models(cls) -> list[SparseModelDescription]:
+        """Lists the supported models.
+
+        Returns:
+            list[SparseModelDescription]: A list of SparseModelDescription objects containing the model information.
+        """
+        return supported_bm42_models
+
+    @classmethod
+    def _load_stopwords(cls, model_dir: Path) -> list[str]:
+        stopwords_path = model_dir / "stopwords.txt"
+        if not stopwords_path.exists():
+            return []
+
+        with open(stopwords_path, "r") as f:
+            return f.read().splitlines()
+
+    def embed(
+        self,
+        documents: str | Iterable[str],
+        batch_size: int = 256,
+        parallel: int | None = None,
+        **kwargs: Any,
+    ) -> Iterable[SparseEmbedding]:
+        """
+        Encode a list of documents into a list of embeddings.
+        We use mean pooling with attention so that the model can handle variable-length inputs.
+
+        Args:
+            documents: Iterator of documents or single document to embed
+            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+            parallel:
+                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                If 0, use all available cores.
+                If None, don't use data-parallel processing, use default onnxruntime threading instead.
+
+        Returns:
+            List of embeddings, one per document
+        """
+        yield from self._embed_documents(
+            model_name=self.model_name,
+            cache_dir=str(self.cache_dir),
+            documents=documents,
+            batch_size=batch_size,
+            parallel=parallel,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_ids=self.device_ids,
+            alpha=self.alpha,
+            local_files_only=self._local_files_only,
+            specific_model_path=self._specific_model_path,
+            extra_session_options=self._extra_session_options,
+        )
+
+    @classmethod
+    def _query_rehash(cls, tokens: Iterable[str]) -> dict[int, float]:
+        result: dict[int, float] = {}
+        for token in tokens:
+            token_id = abs(mmh3.hash(token))
+            result[token_id] = 1.0
+        return result
+
+    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]:
+        """
+        To emulate BM25 behaviour, we don't need to use smart weights in the query, and
+        it's enough to just hash the tokens and assign a weight of 1.0 to them.
+        It is also faster, as we don't need to run the model for the query.
+        """
+        if isinstance(query, str):
+            query = [query]
+
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()
+
+        for text in query:
+            encoded = self.tokenizer.encode(text)  # type: ignore[union-attr]
+            document_tokens_with_ids = enumerate(encoded.tokens)
+            reconstructed = self._reconstruct_bpe(document_tokens_with_ids)
+            filtered = self._filter_pair_tokens(reconstructed)
+            stemmed = self._stem_pair_tokens(filtered)
+
+            yield SparseEmbedding.from_dict(self._query_rehash(token for token, _ in stemmed))
+
+    @classmethod
+    def _get_worker_class(cls) -> Type[TextEmbeddingWorker[SparseEmbedding]]:
+        return Bm42TextEmbeddingWorker
+
+    def token_count(
+        self, texts: str | Iterable[str], batch_size: int = 1024, **kwargs: Any
+    ) -> int:
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()  # loads the tokenizer as well
+        return self._token_count(texts, batch_size=batch_size, **kwargs)
+
+
+class Bm42TextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
+    def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> Bm42:
+        return Bm42(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            **kwargs,
+        )
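
For context on how the class added in this diff is meant to be used, here is a minimal usage sketch. It assumes fastembed-bio installs the same `fastembed` package layout shown in the file list above; the import path and model name are taken from this file, while the `indices`/`values` attributes of `SparseEmbedding` come from `fastembed/sparse/sparse_embedding_base.py` (listed but not shown in this diff) and are an assumption here.

from fastembed.sparse.bm42 import Bm42

# Model name comes from `supported_bm42_models` in this file;
# the first call downloads the ONNX model into the fastembed cache.
model = Bm42("Qdrant/bm42-all-minilm-l6-v2-attentions")

documents = [
    "BM42 replaces token counts with transformer attention weights.",
    "Sparse vectors pair well with Qdrant's IDF modifier.",
]

# Document side: attention-weighted, stopword-filtered, stemmed sparse vectors.
for embedding in model.embed(documents, batch_size=256):
    print(embedding.indices[:5], embedding.values[:5])  # assumed SparseEmbedding attributes

# Query side: no model inference; tokens are hashed and weighted 1.0 (see query_embed above).
for embedding in model.query_embed("what is bm42?"):
    print(embedding.indices, embedding.values)

As the class docstring warns (and `requires_idf=True` in the model description reflects), the resulting vectors are intended for a Qdrant sparse index configured with `modifier="idf"`, which supplies the corpus-frequency half of the BM25 formula at search time.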
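The relationship between the document and query pipelines can also be shown numerically. The sketch below, with made-up attention weights, mirrors `_rescore_vector` and `_query_rehash`: both sides map stemmed tokens to the same `abs(mmh3.hash(token))` ids, documents carry `log(1 + attention_weight) ** alpha` values, and queries carry a constant 1.0, so the dot product reduces to summing the document-side weights of matching tokens (with IDF applied by Qdrant, not here).

import math
import mmh3

alpha = 0.5  # default from Bm42.__init__

# Document side (_rescore_vector): hashed token id -> log(1 + attention weight) ** alpha.
# The attention weights here are illustrative, not real model output.
doc_tokens = {"attent": 0.12, "weight": 0.08}
doc_vector = {abs(mmh3.hash(t)): math.log(1.0 + w) ** alpha for t, w in doc_tokens.items()}

# Query side (_query_rehash): same hash, constant weight 1.0, no ONNX inference.
query_vector = {abs(mmh3.hash(t)): 1.0 for t in ["attent"]}

# Sparse dot product: only tokens present on both sides contribute.
score = sum(weight * query_vector.get(token_id, 0.0) for token_id, weight in doc_vector.items())
print(round(score, 4))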