fastembed-bio 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
@@ -0,0 +1,164 @@
1
+ from typing import Any, Iterable, Type
2
+
3
+
4
+ from fastembed.common.types import NumpyArray
5
+ from fastembed.common.onnx_model import OnnxOutputContext
6
+ from fastembed.common.utils import normalize
7
+ from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
8
+ from fastembed.text.pooled_embedding import PooledEmbedding
9
+ from fastembed.common.model_description import DenseModelDescription, ModelSource
10
+
11
# Model descriptions returned by PooledNormalizedEmbedding._list_supported_models().
# Entries are grouped by the organisation prefix of the public model name.
supported_pooled_normalized_models: list[DenseModelDescription] = [
    # --- sentence-transformers ---
    DenseModelDescription(
        model="sentence-transformers/all-MiniLM-L6-v2",
        dim=384,
        license="apache-2.0",
        size_in_GB=0.09,
        model_file="model.onnx",
        sources=ModelSource(
            hf="qdrant/all-MiniLM-L6-v2-onnx",
            url="https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz",
            _deprecated_tar_struct=True,
        ),
        description=(
            "Text embeddings, Unimodal (text), English, 256 input tokens truncation, "
            "Prefixes for queries/documents: not necessary, 2021 year."
        ),
    ),
    # --- jinaai ---
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-base-en",
        dim=768,
        license="apache-2.0",
        size_in_GB=0.52,
        model_file="onnx/model.onnx",
        sources=ModelSource(hf="xenova/jina-embeddings-v2-base-en"),
        description=(
            "Text embeddings, Unimodal (text), English, 8192 input tokens truncation, "
            "Prefixes for queries/documents: not necessary, 2023 year."
        ),
    ),
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-small-en",
        dim=512,
        license="apache-2.0",
        size_in_GB=0.12,
        model_file="onnx/model.onnx",
        sources=ModelSource(hf="xenova/jina-embeddings-v2-small-en"),
        description=(
            "Text embeddings, Unimodal (text), English, 8192 input tokens truncation, "
            "Prefixes for queries/documents: not necessary, 2023 year."
        ),
    ),
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-base-de",
        dim=768,
        license="apache-2.0",
        size_in_GB=0.32,
        model_file="onnx/model_fp16.onnx",
        sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-de"),
        description=(
            "Text embeddings, Unimodal (text), Multilingual (German, English), 8192 input tokens truncation, "
            "Prefixes for queries/documents: not necessary, 2024 year."
        ),
    ),
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-base-code",
        dim=768,
        license="apache-2.0",
        size_in_GB=0.64,
        model_file="onnx/model.onnx",
        sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-code"),
        description=(
            "Text embeddings, Unimodal (text), Multilingual (English, 30 programming languages), "
            "8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year."
        ),
    ),
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-base-zh",
        dim=768,
        license="apache-2.0",
        size_in_GB=0.64,
        model_file="onnx/model.onnx",
        sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-zh"),
        description=(
            "Text embeddings, Unimodal (text), supports mixed Chinese-English input text, "
            "8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year."
        ),
    ),
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-base-es",
        dim=768,
        license="apache-2.0",
        size_in_GB=0.64,
        model_file="onnx/model.onnx",
        sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-es"),
        description=(
            "Text embeddings, Unimodal (text), supports mixed Spanish-English input text, "
            "8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year."
        ),
    ),
    # --- thenlper ---
    DenseModelDescription(
        model="thenlper/gte-base",
        dim=768,
        license="mit",
        size_in_GB=0.44,
        model_file="onnx/model.onnx",
        sources=ModelSource(hf="thenlper/gte-base"),
        description=(
            "General text embeddings, Unimodal (text), supports English only input text, "
            "512 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year."
        ),
    ),
    DenseModelDescription(
        model="thenlper/gte-large",
        dim=1024,
        license="mit",
        size_in_GB=1.20,
        model_file="model.onnx",
        sources=ModelSource(hf="qdrant/gte-large-onnx"),
        description=(
            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
            "Prefixes for queries/documents: not necessary, 2023 year."
        ),
    ),
]
125
+
126
+
127
class PooledNormalizedEmbedding(PooledEmbedding):
    """Text embedding whose ONNX output is mean-pooled and then passed through `normalize`."""

    @classmethod
    def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
        """Return the worker class paired with this embedding implementation."""
        return PooledNormalizedEmbeddingWorker

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Return the descriptions of the models handled by this class.

        Returns:
            list[DenseModelDescription]: The module-level `supported_pooled_normalized_models` list
                (the same list object, not a copy).
        """
        return supported_pooled_normalized_models

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, **kwargs: Any
    ) -> Iterable[NumpyArray]:
        """Turn the raw model output into final embeddings.

        Args:
            output: ONNX output context; its `model_output` and `attention_mask` are used.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Iterable[NumpyArray]: Result of `normalize` applied to the mean-pooled model output.

        Raises:
            ValueError: If `output.attention_mask` is None.
        """
        mask = output.attention_mask
        if mask is None:
            raise ValueError("attention_mask must be provided for document post-processing")

        # `mean_pooling` is not defined in this class; presumably inherited from PooledEmbedding.
        pooled = self.mean_pooling(output.model_output, mask)
        return normalize(pooled)
150
+
151
+
152
class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker):
    """Worker that builds `PooledNormalizedEmbedding` instances."""

    def init_embedding(
        self,
        model_name: str,
        cache_dir: str,
        **kwargs: Any,
    ) -> OnnxTextEmbedding:
        """Create the embedding model used by this worker.

        Args:
            model_name: Name of the model to load.
            cache_dir: Cache directory forwarded to the embedding constructor.
            **kwargs: Extra keyword arguments forwarded to the embedding constructor unchanged.

        Returns:
            OnnxTextEmbedding: A `PooledNormalizedEmbedding` created with `threads=1`.
        """
        # threads is pinned to 1 here; presumably each worker is meant to stay single-threaded.
        embedding = PooledNormalizedEmbedding(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
        return embedding
@@ -0,0 +1,228 @@
1
+ import warnings
2
+ from typing import Any, Iterable, Sequence, Type
3
+ from dataclasses import asdict
4
+
5
+ from fastembed.common.types import NumpyArray, OnnxProvider, Device
6
+ from fastembed.text.clip_embedding import CLIPOnnxEmbedding
7
+ from fastembed.text.custom_text_embedding import CustomTextEmbedding
8
+ from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
9
+ from fastembed.text.pooled_embedding import PooledEmbedding
10
+ from fastembed.text.multitask_embedding import JinaEmbeddingV3
11
+ from fastembed.text.onnx_embedding import OnnxTextEmbedding
12
+ from fastembed.text.text_embedding_base import TextEmbeddingBase
13
+ from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType
14
+
15
+
16
class TextEmbedding(TextEmbeddingBase):
    """Dispatcher that selects a concrete text embedding implementation by model name.

    On construction, the classes in `EMBEDDINGS_REGISTRY` are scanned in order and the first
    one whose supported-model list contains the requested name (case-insensitive) is
    instantiated and stored in `self.model`. All embedding calls are delegated to it.
    """

    # Order matters: the first class that lists the requested model is used.
    EMBEDDINGS_REGISTRY: list[Type[TextEmbeddingBase]] = [
        OnnxTextEmbedding,
        CLIPOnnxEmbedding,
        PooledNormalizedEmbedding,
        PooledEmbedding,
        JinaEmbeddingV3,
        CustomTextEmbedding,
    ]

    @classmethod
    def list_supported_models(cls) -> list[dict[str, Any]]:
        """Lists the supported models.

        Returns:
            list[dict[str, Any]]: A list of dictionaries containing the model information.
        """
        return [asdict(model) for model in cls._list_supported_models()]

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Collect the model descriptions of every class in `EMBEDDINGS_REGISTRY`, in registry order."""
        result: list[DenseModelDescription] = []
        for embedding in cls.EMBEDDINGS_REGISTRY:
            result.extend(embedding._list_supported_models())
        return result

    @classmethod
    def add_custom_model(
        cls,
        model: str,
        pooling: PoolingType,
        normalization: bool,
        sources: ModelSource,
        dim: int,
        model_file: str = "onnx/model.onnx",
        description: str = "",
        license: str = "",
        size_in_gb: float = 0.0,
        additional_files: list[str] | None = None,
    ) -> None:
        """Register a user-supplied model with `CustomTextEmbedding`.

        Args:
            model: Name under which the model is registered.
            pooling: Pooling type, forwarded to `CustomTextEmbedding.add_model`.
            normalization: Normalization flag, forwarded to `CustomTextEmbedding.add_model`.
            sources: Where the model files are obtained from.
            dim: Embedding dimension recorded in the model description.
            model_file: Path of the ONNX file, stored in the model description.
            description: Free-form description stored with the model.
            license: License identifier stored with the model.
            size_in_gb: Model size stored with the model (as `size_in_GB`).
            additional_files: Extra files stored with the model; None is treated as an empty list.

        Raises:
            ValueError: If a model with the same name (case-insensitive) is already registered.
        """
        requested_name = model.lower()
        for registered_model in cls._list_supported_models():
            if requested_name == registered_model.model.lower():
                raise ValueError(
                    f"Model {model} is already registered in TextEmbedding, if you still want to add this model, "
                    f"please use another model name"
                )

        CustomTextEmbedding.add_model(
            DenseModelDescription(
                model=model,
                sources=sources,
                dim=dim,
                model_file=model_file,
                description=description,
                license=license,
                size_in_GB=size_in_gb,
                additional_files=additional_files or [],
            ),
            pooling=pooling,
            normalization=normalization,
        )

    def __init__(
        self,
        model_name: str = "BAAI/bge-small-en-v1.5",
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        lazy_load: bool = False,
        **kwargs: Any,
    ):
        """Instantiate the implementation that supports `model_name`.

        Args:
            model_name: Name of the model to use; matched case-insensitively.
            cache_dir: Forwarded to the base class and the selected implementation.
            threads: Forwarded to the base class and the selected implementation.
            providers: Forwarded to the selected implementation.
            cuda: Forwarded to the selected implementation.
            device_ids: Forwarded to the selected implementation.
            lazy_load: Forwarded to the selected implementation.
            **kwargs: Extra keyword arguments forwarded to the base class and the implementation.

        Raises:
            ValueError: If no class in `EMBEDDINGS_REGISTRY` supports `model_name`.
        """
        super().__init__(model_name, cache_dir, threads, **kwargs)
        # Lower-case once; every comparison below is case-insensitive.
        requested_name = model_name.lower()

        if requested_name == "nomic-ai/nomic-embed-text-v1.5-Q".lower():
            warnings.warn(
                "The model 'nomic-ai/nomic-embed-text-v1.5-Q' has been updated on HuggingFace. Please review "
                "the latest documentation on HF and release notes to ensure compatibility with your workflow. ",
                UserWarning,
                stacklevel=2,
            )
        if requested_name in {
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".lower(),
            "thenlper/gte-large".lower(),
            "intfloat/multilingual-e5-large".lower(),
            "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".lower(),
        }:
            warnings.warn(
                f"The model {model_name} now uses mean pooling instead of CLS embedding. "
                f"In order to preserve the previous behaviour, consider either pinning fastembed version to 0.5.1 or "
                "using `add_custom_model` functionality.",
                UserWarning,
                stacklevel=2,
            )

        for embedding_cls in self.EMBEDDINGS_REGISTRY:
            supported_models = embedding_cls._list_supported_models()
            if any(requested_name == model.model.lower() for model in supported_models):
                self.model = embedding_cls(
                    model_name=model_name,
                    cache_dir=cache_dir,
                    threads=threads,
                    providers=providers,
                    cuda=cuda,
                    device_ids=device_ids,
                    lazy_load=lazy_load,
                    **kwargs,
                )
                return

        raise ValueError(
            f"Model {model_name} is not supported in TextEmbedding. "
            "Please check the supported models using `TextEmbedding.list_supported_models()`"
        )

    @property
    def embedding_size(self) -> int:
        """Get the embedding size of the current model (looked up once, then cached)."""
        if self._embedding_size is None:
            self._embedding_size = self.get_embedding_size(self.model_name)
        return self._embedding_size

    @classmethod
    def get_embedding_size(cls, model_name: str) -> int:
        """Get the embedding size of the passed model

        Args:
            model_name (str): The name of the model to get embedding size for.

        Returns:
            int: The size of the embedding.

        Raises:
            ValueError: If the model name is not found in the supported models.
        """
        descriptions = cls._list_supported_models()
        embedding_size: int | None = None
        for description in descriptions:
            if description.model.lower() == model_name.lower():
                embedding_size = description.dim
                break
        if embedding_size is None:
            model_names = [description.model for description in descriptions]
            raise ValueError(
                f"Embedding size for model {model_name} was None. "
                f"Available model names: {model_names}"
            )
        return embedding_size

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """
        Encode documents into embeddings by delegating to the selected model implementation.
        Pooling and any other post-processing are decided by that implementation.

        Args:
            documents: Iterator of documents or single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.
            **kwargs: Additional keyword arguments forwarded to the underlying model's `embed`.

        Yields:
            NumpyArray: Embeddings, one per document.
        """
        yield from self.model.embed(documents, batch_size, parallel, **kwargs)

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """
        Embeds queries

        Args:
            query (str | Iterable[str]): The query to embed, or an iterable e.g. list of queries.
            **kwargs: Additional keyword arguments forwarded to the underlying model's `query_embed`.

        Yields:
            NumpyArray: The embeddings.
        """
        # This is model-specific, so that different models can have specialized implementations
        yield from self.model.query_embed(query, **kwargs)

    def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """
        Embeds a list of text passages into a list of embeddings.

        Args:
            texts (Iterable[str]): The list of texts to embed.
            **kwargs: Additional keyword argument to pass to the embed method.

        Yields:
            NumpyArray: The embeddings.
        """
        # This is model-specific, so that different models can have specialized implementations
        yield from self.model.passage_embed(texts, **kwargs)

    def token_count(
        self, texts: str | Iterable[str], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        """Returns the number of tokens in the texts.

        Args:
            texts (str | Iterable[str]): The text or texts whose tokens are counted.
            batch_size (int): Batch size forwarded to the underlying model's `token_count`.
            **kwargs: Additional keyword arguments forwarded to the underlying model's `token_count`.

        Returns:
            int: Sum of number of tokens in the texts.
        """
        return self.model.token_count(texts, batch_size=batch_size, **kwargs)
@@ -0,0 +1,75 @@
1
+ from typing import Iterable, Any
2
+
3
+ from fastembed.common.model_description import DenseModelDescription
4
+ from fastembed.common.types import NumpyArray
5
+ from fastembed.common.model_management import ModelManagement
6
+
7
+
8
class TextEmbeddingBase(ModelManagement[DenseModelDescription]):
    """Common interface for text embedding implementations.

    Stores the basic configuration and provides default `passage_embed` / `query_embed`
    implementations on top of `embed`, which subclasses must provide.
    """

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        **kwargs: Any,
    ):
        """Record the model configuration.

        Args:
            model_name: Name of the model.
            cache_dir: Cache directory, stored as given.
            threads: Thread count, stored as given.
            **kwargs: Only `local_files_only` (default False) is read; other keys are ignored here.
        """
        # Filled in lazily by subclasses that implement `embedding_size`.
        self._embedding_size: int | None = None
        self._local_files_only = kwargs.pop("local_files_only", False)
        self.threads = threads
        self.cache_dir = cache_dir
        self.model_name = model_name

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """Embed documents; not implemented in the base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """
        Embed text passages by forwarding them to `embed`.

        Args:
            texts (Iterable[str]): The passages to embed.
            **kwargs: Extra keyword arguments handed to `embed`.

        Yields:
            NumpyArray: The embeddings.
        """
        # Subclasses may override this with a model-specific implementation.
        embeddings = self.embed(texts, **kwargs)
        yield from embeddings

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """
        Embed one query or several queries by forwarding them to `embed`.

        Args:
            query (str | Iterable[str]): A single query string, or an iterable of queries.
            **kwargs: Extra keyword arguments handed to `embed`.

        Yields:
            NumpyArray: The embeddings.
        """
        # Subclasses may override this with a model-specific implementation.
        # A lone string is wrapped in a list before being handed to `embed`.
        queries = [query] if isinstance(query, str) else query
        yield from self.embed(queries, **kwargs)

    @classmethod
    def get_embedding_size(cls, model_name: str) -> int:
        """Return the embedding size of the given model; must be provided by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    @property
    def embedding_size(self) -> int:
        """Return the embedding size of the current model; must be provided by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    def token_count(self, texts: str | Iterable[str], **kwargs: Any) -> int:
        """Return the number of tokens in the texts; must be provided by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")