fastembed-bio 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/sparse/bm25.py ADDED
@@ -0,0 +1,359 @@
import os
from collections import defaultdict
from multiprocessing import get_all_start_methods
from pathlib import Path
from typing import Any, Iterable, Type

import mmh3
import numpy as np
from py_rust_stemmers import SnowballStemmer

from fastembed.common.model_description import ModelSource, SparseModelDescription
from fastembed.common.utils import (
    define_cache_dir,
    get_all_punctuation,
    iter_batch,
    remove_non_alphanumeric,
)
from fastembed.parallel_processor import ParallelWorkerPool, Worker
from fastembed.sparse.sparse_embedding_base import (
    SparseEmbedding,
    SparseTextEmbeddingBase,
)
from fastembed.sparse.utils.tokenizer import SimpleTokenizer


supported_languages = [
    "arabic",
    "danish",
    "dutch",
    "english",
    "finnish",
    "french",
    "german",
    "greek",
    "hungarian",
    "italian",
    "norwegian",
    "portuguese",
    "romanian",
    "russian",
    "spanish",
    "swedish",
    "tamil",
    "turkish",
]

supported_bm25_models: list[SparseModelDescription] = [
    SparseModelDescription(
        model="Qdrant/bm25",
        vocab_size=0,
        description="BM25 as sparse embeddings meant to be used with Qdrant",
        license="apache-2.0",
        size_in_GB=0.01,
        sources=ModelSource(hf="Qdrant/bm25"),
        additional_files=[f"{lang}.txt" for lang in supported_languages],
        requires_idf=True,
        model_file="mock.file",
    ),
]

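# Editorial note (not part of the package source): the "Qdrant/bm25" entry ships
# only per-language stopword lists (the `additional_files` above);
# `model_file="mock.file"` is a placeholder, since BM25 needs no ONNX weights,
# and `requires_idf=True` signals that IDF must be applied on the Qdrant side
# (modifier="idf" on the sparse vector index).
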
class Bm25(SparseTextEmbeddingBase):
    """Implements traditional BM25 in the form of sparse embeddings.
    Uses the count of tokens in the document to evaluate the importance of each token.

    WARNING: This model is expected to be used with `modifier="idf"` in the sparse vector index of Qdrant.

    BM25 formula:

    score(q, d) = SUM[ IDF(q_i) * (f(q_i, d) * (k + 1)) / (f(q_i, d) + k * (1 - b + b * (|d| / avg_len))) ],

    where IDF is the inverse document frequency, computed on Qdrant's side,
    f(q_i, d) is the term frequency of the token q_i in the document d,
    and k, b, avg_len are hyperparameters, described below.

    Args:
        model_name (str): The name of the model to use.
        cache_dir (str, optional): The path to the cache directory.
            Can be set using the `FASTEMBED_CACHE_PATH` env variable.
            Defaults to `fastembed_cache` in the system's temp directory.
        k (float, optional): The k parameter in the BM25 formula. Defines the saturation of the term frequency,
            i.e. how quickly repeated occurrences of a term stop increasing the score. Defaults to 1.2.
        b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
            Defaults to 0.75.
        avg_len (float, optional): The average length of the documents in the corpus. Defaults to 256.0.
        language (str): Specifies the language for the stemmer.
        disable_stemmer (bool): Disable the stemmer.

    Raises:
        ValueError: If the model_name is not in the format <org>/<model>, e.g. BAAI/bge-base-en.
    """

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        k: float = 1.2,
        b: float = 0.75,
        avg_len: float = 256.0,
        language: str = "english",
        token_max_length: int = 40,
        disable_stemmer: bool = False,
        specific_model_path: str | None = None,
        **kwargs: Any,
    ):
        super().__init__(model_name, cache_dir, **kwargs)

        if language not in supported_languages:
            raise ValueError(f"{language} language is not supported")
        self.language = language

        self.k = k
        self.b = b
        self.avg_len = avg_len

        model_description = self._get_model_description(model_name)
        self.cache_dir = str(define_cache_dir(cache_dir))

        self._specific_model_path = specific_model_path
        self._model_dir = self.download_model(
            model_description,
            self.cache_dir,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
        )

        self.token_max_length = token_max_length
        self.punctuation = set(get_all_punctuation())
        self.disable_stemmer = disable_stemmer

        if disable_stemmer:
            self.stopwords: set[str] = set()
            self.stemmer = None
        else:
            self.stopwords = set(self._load_stopwords(self._model_dir, self.language))
            self.stemmer = SnowballStemmer(language)

        self.tokenizer = SimpleTokenizer

    @classmethod
    def _list_supported_models(cls) -> list[SparseModelDescription]:
        """Lists the supported models.

        Returns:
            list[SparseModelDescription]: A list of SparseModelDescription objects containing the model information.
        """
        return supported_bm25_models

    @classmethod
    def _load_stopwords(cls, model_dir: Path, language: str) -> list[str]:
        stopwords_path = model_dir / f"{language}.txt"
        if not stopwords_path.exists():
            return []

        with open(stopwords_path, "r") as f:
            return f.read().splitlines()

    def _embed_documents(
        self,
        model_name: str,
        cache_dir: str,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        local_files_only: bool = False,
        specific_model_path: str | None = None,
    ) -> Iterable[SparseEmbedding]:
        is_small = False

        if isinstance(documents, str):
            documents = [documents]
            is_small = True

        if isinstance(documents, list) and len(documents) < batch_size:
            is_small = True

        if parallel is None or is_small:
            for batch in iter_batch(documents, batch_size):
                yield from self.raw_embed(batch)
        else:
            if parallel == 0:
                parallel = os.cpu_count()

            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
            params = {
                "model_name": model_name,
                "cache_dir": cache_dir,
                "k": self.k,
                "b": self.b,
                "avg_len": self.avg_len,
                "language": self.language,
                "token_max_length": self.token_max_length,
                "disable_stemmer": self.disable_stemmer,
                "local_files_only": local_files_only,
                "specific_model_path": specific_model_path,
            }
            pool = ParallelWorkerPool(
                num_workers=parallel or 1,
                worker=self._get_worker_class(),
                start_method=start_method,
            )
            for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
                for record in batch:
                    yield record  # type: ignore

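    # Editorial note (not part of the package source): in the parallel branch each
    # worker process rebuilds its own Bm25 instance from `params`, so only plain,
    # picklable values are passed; "forkserver" is used where the platform
    # supports it, otherwise "spawn".
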
    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[SparseEmbedding]:
        """
        Encode a list of documents into a list of sparse embeddings.

        Args:
            documents: Iterator of documents or single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing; encode sequentially in the current process.

        Returns:
            List of embeddings, one per document
        """
        yield from self._embed_documents(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            documents=documents,
            batch_size=batch_size,
            parallel=parallel,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
        )

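    # Editorial examples (not part of the package source) of the `parallel` modes:
    #   list(model.embed(docs))              # sequential (parallel=None)
    #   list(model.embed(docs, parallel=0))  # one worker process per CPU core
    #   list(model.embed(docs, parallel=4))  # four worker processes
    # Inputs smaller than batch_size are always processed sequentially (is_small).
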
    def _stem(self, tokens: list[str]) -> list[str]:
        stemmed_tokens: list[str] = []
        for token in tokens:
            lower_token = token.lower()

            if token in self.punctuation:
                continue

            if lower_token in self.stopwords:
                continue

            if len(token) > self.token_max_length:
                continue

            stemmed_token = self.stemmer.stem_word(lower_token) if self.stemmer else lower_token

            if stemmed_token:
                stemmed_tokens.append(stemmed_token)
        return stemmed_tokens

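    # Editorial example (not part of the package source): with English stopwords
    # loaded, _stem(["The", "quick", "brown", "fox"]) drops "The" (stopword) and
    # returns ["quick", "brown", "fox"], assuming the Snowball stemmer leaves
    # these particular words unchanged.
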
    def raw_embed(
        self,
        documents: list[str],
    ) -> list[SparseEmbedding]:
        embeddings: list[SparseEmbedding] = []
        for document in documents:
            document = remove_non_alphanumeric(document)
            tokens = self.tokenizer.tokenize(document)
            stemmed_tokens = self._stem(tokens)
            token_id2value = self._term_frequency(stemmed_tokens)
            embeddings.append(SparseEmbedding.from_dict(token_id2value))
        return embeddings

    def token_count(self, texts: str | Iterable[str], **kwargs: Any) -> int:
        # Counts raw tokens after cleaning, before stopword removal and stemming.
        token_num = 0
        texts = [texts] if isinstance(texts, str) else texts
        for text in texts:
            document = remove_non_alphanumeric(text)
            tokens = self.tokenizer.tokenize(document)
            token_num += len(tokens)
        return token_num

    def _term_frequency(self, tokens: list[str]) -> dict[int, float]:
        """Calculate the term frequency part of the BM25 formula:

        (
            f(q_i, d) * (k + 1)
        ) / (
            f(q_i, d) + k * (1 - b + b * (|d| / avg_len))
        )

        Args:
            tokens (list[str]): The list of tokens in the document.

        Returns:
            dict[int, float]: The token_id to term frequency mapping.
        """
        tf_map: dict[int, float] = {}
        counter: defaultdict[str, int] = defaultdict(int)
        for stemmed_token in tokens:
            counter[stemmed_token] += 1

        doc_len = len(tokens)
        for stemmed_token in counter:
            token_id = self.compute_token_id(stemmed_token)
            num_occurrences = counter[stemmed_token]
            tf_map[token_id] = num_occurrences * (self.k + 1)
            tf_map[token_id] /= num_occurrences + self.k * (
                1 - self.b + self.b * doc_len / self.avg_len
            )
        return tf_map

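    # Editorial worked example (not part of the package source): for a 4-token
    # document in which a term occurs twice, with the defaults k=1.2, b=0.75,
    # avg_len=256.0:
    #   tf = 2 * (1.2 + 1) / (2 + 1.2 * (1 - 0.75 + 0.75 * 4 / 256.0))
    #      = 4.4 / 2.3140625
    #      ~= 1.901
    # Repeated occurrences saturate toward (k + 1), and documents shorter than
    # avg_len are penalized less by the length-normalization term.
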
    @classmethod
    def compute_token_id(cls, token: str) -> int:
        # mmh3.hash returns a signed 32-bit integer; abs() makes the id
        # non-negative so it can serve as a sparse vector index.
        return abs(mmh3.hash(token))

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]:
        """To emulate BM25 behaviour, we don't need to use weights in the query;
        it's enough to just hash the tokens and assign a weight of 1.0 to them.
        """
        if isinstance(query, str):
            query = [query]

        for text in query:
            text = remove_non_alphanumeric(text)
            tokens = self.tokenizer.tokenize(text)
            stemmed_tokens = self._stem(tokens)
            token_ids = np.array(
                list(set(self.compute_token_id(token) for token in stemmed_tokens)),
                dtype=np.int32,
            )
            values = np.ones_like(token_ids)
            yield SparseEmbedding(indices=token_ids, values=values)

    @classmethod
    def _get_worker_class(cls) -> Type["Bm25Worker"]:
        return Bm25Worker


class Bm25Worker(Worker):
    def __init__(
        self,
        model_name: str,
        cache_dir: str,
        **kwargs: Any,
    ):
        self.model = self.init_embedding(model_name, cache_dir, **kwargs)

    @classmethod
    def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "Bm25Worker":
        return cls(model_name=model_name, cache_dir=cache_dir, **kwargs)

    def process(
        self, items: Iterable[tuple[int, Any]]
    ) -> Iterable[tuple[int, list[SparseEmbedding]]]:
        for idx, batch in items:
            # No ONNX runs here: BM25 is computed in pure Python via raw_embed.
            embeddings = self.model.raw_embed(batch)
            yield idx, embeddings

    @staticmethod
    def init_embedding(model_name: str, cache_dir: str, **kwargs: Any) -> Bm25:
        return Bm25(model_name=model_name, cache_dir=cache_dir, **kwargs)
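
For reference, a minimal end-to-end sketch of the module above (hypothetical usage written for this review; it relies only on the classes defined in this diff and on the `Qdrant/bm25` entry registered in `supported_bm25_models`, whose stopword lists are fetched on first use):

from fastembed.sparse.bm25 import Bm25

# Build the model; only stopword lists are downloaded, no ONNX weights.
model = Bm25("Qdrant/bm25", language="english")

documents = [
    "The quick brown fox jumps over the lazy dog",
    "BM25 is a ranking function used by search engines",
]

# Document side: saturated, length-normalized term frequencies.
doc_embeddings = list(model.embed(documents))

# Query side: hashed tokens with a flat weight of 1; IDF is expected to be
# applied by Qdrant (modifier="idf" on the sparse vector index).
query_embedding = next(model.query_embed("quick fox"))

print(doc_embeddings[0].indices, doc_embeddings[0].values)
print(query_embedding.indices, query_embedding.values)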