fastembed-bio 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/rerank/cross_encoder/onnx_text_model.py
@@ -0,0 +1,204 @@
+ import os
+ from multiprocessing import get_all_start_methods
+ from pathlib import Path
+ from typing import Any, Iterable, Sequence, Type
+
+ import numpy as np
+ from tokenizers import Encoding
+
+ from fastembed.common.onnx_model import (
+     EmbeddingWorker,
+     OnnxModel,
+     OnnxOutputContext,
+     OnnxProvider,
+ )
+ from fastembed.common.types import NumpyArray, Device
+ from fastembed.common.preprocessor_utils import load_tokenizer
+ from fastembed.common.utils import iter_batch
+ from fastembed.parallel_processor import ParallelWorkerPool
+
+
+ class OnnxCrossEncoderModel(OnnxModel[float]):
+     ONNX_OUTPUT_NAMES: list[str] | None = None
+
+     @classmethod
+     def _get_worker_class(cls) -> Type["TextRerankerWorker"]:
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def _load_onnx_model(
+         self,
+         model_dir: Path,
+         model_file: str,
+         threads: int | None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_id: int | None = None,
+         extra_session_options: dict[str, Any] | None = None,
+     ) -> None:
+         super()._load_onnx_model(
+             model_dir=model_dir,
+             model_file=model_file,
+             threads=threads,
+             providers=providers,
+             cuda=cuda,
+             device_id=device_id,
+             extra_session_options=extra_session_options,
+         )
+         self.tokenizer, _ = load_tokenizer(model_dir=model_dir)
+         assert self.tokenizer is not None
+
+     def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]:
+         return self.tokenizer.encode_batch(pairs)  # type: ignore[union-attr]
+
+     def _build_onnx_input(self, tokenized_input: list[Encoding]) -> dict[str, NumpyArray]:
+         input_names: set[str] = {node.name for node in self.model.get_inputs()}  # type: ignore[union-attr]
+         inputs: dict[str, NumpyArray] = {
+             "input_ids": np.array([enc.ids for enc in tokenized_input], dtype=np.int64),
+         }
+         if "token_type_ids" in input_names:
+             inputs["token_type_ids"] = np.array(
+                 [enc.type_ids for enc in tokenized_input], dtype=np.int64
+             )
+         if "attention_mask" in input_names:
+             inputs["attention_mask"] = np.array(
+                 [enc.attention_mask for enc in tokenized_input], dtype=np.int64
+             )
+         return inputs
+
+     def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext:
+         pairs = [(query, doc) for doc in documents]
+         return self.onnx_embed_pairs(pairs, **kwargs)
+
+     def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxOutputContext:
+         tokenized_input = self.tokenize(pairs, **kwargs)
+         inputs = self._build_onnx_input(tokenized_input)
+         onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
+         outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)  # type: ignore[union-attr]
+         relevant_output = outputs[0]
+         scores: NumpyArray = relevant_output[:, 0]
+         return OnnxOutputContext(model_output=scores)
+
+     def _rerank_documents(
+         self, query: str, documents: Iterable[str], batch_size: int, **kwargs: Any
+     ) -> Iterable[float]:
+         if not hasattr(self, "model") or self.model is None:
+             self.load_onnx_model()
+         for batch in iter_batch(documents, batch_size):
+             yield from self._post_process_onnx_output(self.onnx_embed(query, batch, **kwargs))
+
+     def _rerank_pairs(
+         self,
+         model_name: str,
+         cache_dir: str,
+         pairs: Iterable[tuple[str, str]],
+         batch_size: int,
+         parallel: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         local_files_only: bool = False,
+         specific_model_path: str | None = None,
+         extra_session_options: dict[str, Any] | None = None,
+         **kwargs: Any,
+     ) -> Iterable[float]:
+         is_small = False
+
+         if isinstance(pairs, tuple):
+             pairs = [pairs]
+             is_small = True
+
+         if isinstance(pairs, list):
+             if len(pairs) < batch_size:
+                 is_small = True
+
+         if parallel is None or is_small:
+             if not hasattr(self, "model") or self.model is None:
+                 self.load_onnx_model()
+             for batch in iter_batch(pairs, batch_size):
+                 yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs))
+         else:
+             if parallel == 0:
+                 parallel = os.cpu_count()
+
+             start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
+             params = {
+                 "model_name": model_name,
+                 "cache_dir": cache_dir,
+                 "providers": providers,
+                 "local_files_only": local_files_only,
+                 "specific_model_path": specific_model_path,
+                 **kwargs,
+             }
+
+             if extra_session_options is not None:
+                 params.update(extra_session_options)
+
+             pool = ParallelWorkerPool(
+                 num_workers=parallel or 1,
+                 worker=self._get_worker_class(),
+                 cuda=cuda,
+                 device_ids=device_ids,
+                 start_method=start_method,
+             )
+             for batch in pool.ordered_map(iter_batch(pairs, batch_size), **params):
+                 yield from self._post_process_onnx_output(batch)  # type: ignore
+
+     def _post_process_onnx_output(
+         self, output: OnnxOutputContext, **kwargs: Any
+     ) -> Iterable[float]:
+         """Post-process the ONNX model output to convert it into a usable format.
+
+         Args:
+             output (OnnxOutputContext): The raw output from the ONNX model.
+             **kwargs: Additional keyword arguments that may be needed by specific implementations.
+
+         Returns:
+             Iterable[float]: Post-processed output as an iterable of float values.
+         """
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def _preprocess_onnx_input(
+         self, onnx_input: dict[str, NumpyArray], **kwargs: Any
+     ) -> dict[str, NumpyArray]:
+         """
+         Preprocess the onnx input.
+         """
+         return onnx_input
+
+     def _token_count(
+         self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **_: Any
+     ) -> int:
+         if not hasattr(self, "model") or self.model is None:
+             self.load_onnx_model()  # loads the tokenizer as well
+
+         token_num = 0
+         assert self.tokenizer is not None
+         for batch in iter_batch(pairs, batch_size):
+             for tokens in self.tokenizer.encode_batch(batch):
+                 token_num += sum(tokens.attention_mask)
+
+         return token_num
+
+
+ class TextRerankerWorker(EmbeddingWorker[float]):
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str,
+         **kwargs: Any,
+     ):
+         self.model: OnnxCrossEncoderModel
+         super().__init__(model_name, cache_dir, **kwargs)
+
+     def init_embedding(
+         self,
+         model_name: str,
+         cache_dir: str,
+         **kwargs: Any,
+     ) -> OnnxCrossEncoderModel:
+         raise NotImplementedError()
+
+     def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
+         for idx, batch in items:
+             onnx_output = self.model.onnx_embed_pairs(batch)
+             yield idx, onnx_output
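The two `NotImplementedError` hooks above are the whole subclass contract: `_get_worker_class` names the multiprocessing worker, and `_post_process_onnx_output` converts the raw logits that `onnx_embed_pairs` extracts from `outputs[0][:, 0]` into final scores. A minimal sketch of what a concrete subclass might look like — the class name and the sigmoid post-processing are illustrative, not part of the package:

```python
from typing import Any, Iterable, Type

import numpy as np

from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.rerank.cross_encoder.onnx_text_model import (
    OnnxCrossEncoderModel,
    TextRerankerWorker,
)


class SigmoidCrossEncoder(OnnxCrossEncoderModel):  # hypothetical subclass
    @classmethod
    def _get_worker_class(cls) -> Type[TextRerankerWorker]:
        # A real implementation returns a worker subclass whose
        # init_embedding constructs this model class.
        return TextRerankerWorker

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, **kwargs: Any
    ) -> Iterable[float]:
        # output.model_output holds the logits taken from outputs[0][:, 0];
        # squashing them through a sigmoid is one possible choice here.
        return (float(1.0 / (1.0 + np.exp(-logit))) for logit in output.model_output)
```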
fastembed/rerank/cross_encoder/text_cross_encoder.py
@@ -0,0 +1,178 @@
+ from typing import Any, Iterable, Sequence, Type
+ from dataclasses import asdict
+
+ from fastembed.common import OnnxProvider
+ from fastembed.common.types import Device
+ from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
+ from fastembed.rerank.cross_encoder.custom_text_cross_encoder import CustomTextCrossEncoder
+
+ from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
+ from fastembed.common.model_description import (
+     ModelSource,
+     BaseModelDescription,
+ )
+
+
+ class TextCrossEncoder(TextCrossEncoderBase):
+     CROSS_ENCODER_REGISTRY: list[Type[TextCrossEncoderBase]] = [
+         OnnxTextCrossEncoder,
+         CustomTextCrossEncoder,
+     ]
+
+     @classmethod
+     def list_supported_models(cls) -> list[dict[str, Any]]:
+         """Lists the supported models.
+
+         Returns:
+             list[dict[str, Any]]: A list of dictionaries containing the model information.
+
+         Example:
+             ```
+             [
+                 {
+                     "model": "Xenova/ms-marco-MiniLM-L-6-v2",
+                     "size_in_GB": 0.08,
+                     "sources": {
+                         "hf": "Xenova/ms-marco-MiniLM-L-6-v2",
+                     },
+                     "model_file": "onnx/model.onnx",
+                     "description": "MiniLM-L-6-v2 model optimized for re-ranking tasks.",
+                     "license": "apache-2.0",
+                 }
+             ]
+             ```
+         """
+         return [asdict(model) for model in cls._list_supported_models()]
+
+     @classmethod
+     def _list_supported_models(cls) -> list[BaseModelDescription]:
+         result: list[BaseModelDescription] = []
+         for encoder in cls.CROSS_ENCODER_REGISTRY:
+             result.extend(encoder._list_supported_models())
+         return result
+
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         lazy_load: bool = False,
+         **kwargs: Any,
+     ):
+         super().__init__(model_name, cache_dir, threads, **kwargs)
+
+         for CROSS_ENCODER_TYPE in self.CROSS_ENCODER_REGISTRY:
+             supported_models = CROSS_ENCODER_TYPE._list_supported_models()
+             if any(model_name.lower() == model.model.lower() for model in supported_models):
+                 self.model = CROSS_ENCODER_TYPE(
+                     model_name=model_name,
+                     cache_dir=cache_dir,
+                     threads=threads,
+                     providers=providers,
+                     cuda=cuda,
+                     device_ids=device_ids,
+                     lazy_load=lazy_load,
+                     **kwargs,
+                 )
+                 return
+
+         raise ValueError(
+             f"Model {model_name} is not supported in TextCrossEncoder. "
+             "Please check the supported models using `TextCrossEncoder.list_supported_models()`"
+         )
+
+     def rerank(
+         self, query: str, documents: Iterable[str], batch_size: int = 64, **kwargs: Any
+     ) -> Iterable[float]:
+         """Rerank a list of documents based on a query.
+
+         Args:
+             query: Query to rerank the documents against
+             documents: Iterator of documents to rerank
+             batch_size: Batch size for reranking
+
+         Returns:
+             Iterable of scores for each document
+         """
+         yield from self.model.rerank(query, documents, batch_size=batch_size, **kwargs)
+
+     def rerank_pairs(
+         self,
+         pairs: Iterable[tuple[str, str]],
+         batch_size: int = 64,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[float]:
+         """
+         Rerank a list of query-document pairs.
+
+         Args:
+             pairs (Iterable[tuple[str, str]]): An iterable of tuples, where each tuple contains a query
+                 and a document to be scored together.
+             batch_size (int, optional): The number of query-document pairs to process in a single batch.
+                 Defaults to 64.
+             parallel (int | None, optional): The number of parallel processes to use for reranking.
+                 If None, parallelization is disabled. Defaults to None.
+             **kwargs (Any): Additional arguments to pass to the underlying reranking model.
+
+         Returns:
+             Iterable[float]: An iterable of scores corresponding to each query-document pair in the input.
+                 Higher scores indicate a stronger match between the query and the document.
+
+         Example:
+             >>> encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2")
+             >>> pairs = [("What is AI?", "Artificial intelligence is ..."), ("What is ML?", "Machine learning is ...")]
+             >>> scores = list(encoder.rerank_pairs(pairs))
+             >>> print(list(map(lambda x: round(x, 2), scores)))
+             [-1.24, -10.6]
+         """
+         yield from self.model.rerank_pairs(
+             pairs, batch_size=batch_size, parallel=parallel, **kwargs
+         )
+
+     @classmethod
+     def add_custom_model(
+         cls,
+         model: str,
+         sources: ModelSource,
+         model_file: str = "onnx/model.onnx",
+         description: str = "",
+         license: str = "",
+         size_in_gb: float = 0.0,
+         additional_files: list[str] | None = None,
+     ) -> None:
+         registered_models = cls._list_supported_models()
+         for registered_model in registered_models:
+             if model == registered_model.model:
+                 raise ValueError(
+                     f"Model {model} is already registered in CrossEncoderModel; if you still want "
+                     f"to add this model, please use another model name"
+                 )
+
+         CustomTextCrossEncoder.add_model(
+             BaseModelDescription(
+                 model=model,
+                 sources=sources,
+                 model_file=model_file,
+                 description=description,
+                 license=license,
+                 size_in_GB=size_in_gb,
+                 additional_files=additional_files or [],
+             )
+         )
+
+     def token_count(
+         self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **kwargs: Any
+     ) -> int:
+         """Returns the number of tokens in the pairs.
+
+         Args:
+             pairs: Iterable of tuples, where each tuple contains a query and a document to be tokenized
+             batch_size: Batch size for tokenizing
+
+         Returns:
+             int: Overall number of tokens in the pairs
+         """
+         return self.model.token_count(pairs, batch_size=batch_size, **kwargs)
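Taken together with the registry dispatch in `__init__`, typical usage looks like the sketch below, adapted from the `rerank_pairs` docstring. It assumes the subpackage `__init__` re-exports `TextCrossEncoder` (as upstream fastembed does) and that the model is downloaded on first use:

```python
from fastembed.rerank.cross_encoder import TextCrossEncoder

encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2")

# Score a set of documents against one query; scores are raw logits,
# higher means a stronger query-document match.
scores = list(
    encoder.rerank("What is AI?", ["Artificial intelligence is ...", "Paris is ..."])
)

# Or score arbitrary (query, document) pairs in one pass.
pair_scores = list(
    encoder.rerank_pairs(
        [
            ("What is AI?", "Artificial intelligence is ..."),
            ("What is ML?", "Machine learning is ..."),
        ]
    )
)
```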
fastembed/rerank/cross_encoder/text_cross_encoder_base.py
@@ -0,0 +1,63 @@
+ from typing import Any, Iterable
+
+ from fastembed.common.model_description import BaseModelDescription
+ from fastembed.common.model_management import ModelManagement
+
+
+ class TextCrossEncoderBase(ModelManagement[BaseModelDescription]):
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         **kwargs: Any,
+     ):
+         self.model_name = model_name
+         self.cache_dir = cache_dir
+         self.threads = threads
+         self._local_files_only = kwargs.pop("local_files_only", False)
+
+     def rerank(
+         self,
+         query: str,
+         documents: Iterable[str],
+         batch_size: int = 64,
+         **kwargs: Any,
+     ) -> Iterable[float]:
+         """Rerank a list of documents given a query.
+
+         Args:
+             query (str): The query to rerank the documents against.
+             documents (Iterable[str]): The list of texts to rerank.
+             batch_size (int): The batch size to use for reranking.
+             **kwargs: Additional keyword arguments to pass to the rerank method.
+
+         Yields:
+             Iterable[float]: The scores of the reranked documents.
+         """
+         raise NotImplementedError("This method should be overridden by subclasses")
+
+     def rerank_pairs(
+         self,
+         pairs: Iterable[tuple[str, str]],
+         batch_size: int = 64,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[float]:
+         """Rerank query-document pairs.
+
+         Args:
+             pairs (Iterable[tuple[str, str]]): Query-document pairs to rerank.
+             batch_size (int): The batch size to use for reranking.
+             parallel:
+                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                 If 0, use all available cores.
+                 If None, don't use data-parallel processing; use default onnxruntime threading instead.
+             **kwargs: Additional keyword arguments to pass to the rerank method.
+
+         Yields:
+             Iterable[float]: Scores for each individual pair.
+         """
+         raise NotImplementedError("This method should be overridden by subclasses")
+
+     def token_count(self, pairs: Iterable[tuple[str, str]], **kwargs: Any) -> int:
+         """Returns the number of tokens in the pairs."""
+         raise NotImplementedError("This method should be overridden by subclasses")
fastembed/sparse/__init__.py
@@ -0,0 +1,4 @@
+ from fastembed.sparse.sparse_embedding_base import SparseEmbedding
+ from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
+
+ __all__ = ["SparseEmbedding", "SparseTextEmbedding"]