fastembed-bio 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/text/onnx_embedding.py
@@ -0,0 +1,353 @@
+ from typing import Any, Iterable, Sequence, Type
+
+ from fastembed.common.types import NumpyArray, OnnxProvider, Device
+ from fastembed.common.onnx_model import OnnxOutputContext
+ from fastembed.common.utils import define_cache_dir, normalize
+ from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker
+ from fastembed.text.text_embedding_base import TextEmbeddingBase
+ from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+ supported_onnx_models: list[DenseModelDescription] = [
+     DenseModelDescription(
+         model="BAAI/bge-base-en",
+         dim=768,
+         description=(
+             "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+             "Prefixes for queries/documents: necessary, 2023 year."
+         ),
+         license="mit",
+         size_in_GB=0.42,
+         sources=ModelSource(
+             hf="Qdrant/fast-bge-base-en",
+             url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz",
+             _deprecated_tar_struct=True,
+         ),
+         model_file="model_optimized.onnx",
+     ),
+     DenseModelDescription(
+         model="BAAI/bge-base-en-v1.5",
+         dim=768,
+         description=(
+             "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+             "Prefixes for queries/documents: not so necessary, 2023 year."
+         ),
+         license="mit",
+         size_in_GB=0.21,
+         sources=ModelSource(
+             hf="qdrant/bge-base-en-v1.5-onnx-q",
+             url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
+             _deprecated_tar_struct=True,
+         ),
+         model_file="model_optimized.onnx",
+     ),
+     DenseModelDescription(
+         model="BAAI/bge-large-en-v1.5",
+         dim=1024,
+         description=(
+             "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+             "Prefixes for queries/documents: not so necessary, 2023 year."
+         ),
+         license="mit",
+         size_in_GB=1.20,
+         sources=ModelSource(hf="qdrant/bge-large-en-v1.5-onnx"),
+         model_file="model.onnx",
+     ),
+     DenseModelDescription(
+         model="BAAI/bge-small-en",
+         dim=384,
+         description=(
+             "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+             "Prefixes for queries/documents: necessary, 2023 year."
+         ),
+         license="mit",
+         size_in_GB=0.13,
+         sources=ModelSource(
+             hf="Qdrant/bge-small-en",
+             url="https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz",
+             _deprecated_tar_struct=True,
+         ),
+         model_file="model_optimized.onnx",
+     ),
+     DenseModelDescription(
+         model="BAAI/bge-small-en-v1.5",
+         dim=384,
+         description=(
+             "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+             "Prefixes for queries/documents: not so necessary, 2023 year."
+         ),
+         license="mit",
+         size_in_GB=0.067,
+         sources=ModelSource(hf="qdrant/bge-small-en-v1.5-onnx-q"),
+         model_file="model_optimized.onnx",
+     ),
+     DenseModelDescription(
+         model="BAAI/bge-small-zh-v1.5",
+         dim=512,
+         description=(
+             "Text embeddings, Unimodal (text), Chinese, 512 input tokens truncation, "
+             "Prefixes for queries/documents: not so necessary, 2023 year."
+         ),
+         license="mit",
+         size_in_GB=0.09,
+         sources=ModelSource(
+             hf="Qdrant/bge-small-zh-v1.5",
+             url="https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz",
+             _deprecated_tar_struct=True,
+         ),
+         model_file="model_optimized.onnx",
+     ),
+     DenseModelDescription(
+         model="mixedbread-ai/mxbai-embed-large-v1",
+         dim=1024,
+         description=(
+             "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+             "Prefixes for queries/documents: necessary, 2024 year."
+         ),
+         license="apache-2.0",
+         size_in_GB=0.64,
+         sources=ModelSource(hf="mixedbread-ai/mxbai-embed-large-v1"),
+         model_file="onnx/model.onnx",
+     ),
+     DenseModelDescription(
+         model="snowflake/snowflake-arctic-embed-xs",
+         dim=384,
+         description=(
+             "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+             "Prefixes for queries/documents: necessary, 2024 year."
+         ),
+         license="apache-2.0",
+         size_in_GB=0.09,
+         sources=ModelSource(hf="snowflake/snowflake-arctic-embed-xs"),
+         model_file="onnx/model.onnx",
+     ),
+     DenseModelDescription(
+         model="snowflake/snowflake-arctic-embed-s",
+         dim=384,
+         description=(
+             "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+             "Prefixes for queries/documents: necessary, 2024 year."
+         ),
+         license="apache-2.0",
+         size_in_GB=0.13,
+         sources=ModelSource(hf="snowflake/snowflake-arctic-embed-s"),
+         model_file="onnx/model.onnx",
+     ),
+     DenseModelDescription(
+         model="snowflake/snowflake-arctic-embed-m",
+         dim=768,
+         description=(
+             "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+             "Prefixes for queries/documents: necessary, 2024 year."
+         ),
+         license="apache-2.0",
+         size_in_GB=0.43,
+         sources=ModelSource(hf="Snowflake/snowflake-arctic-embed-m"),
+         model_file="onnx/model.onnx",
+     ),
+     DenseModelDescription(
+         model="snowflake/snowflake-arctic-embed-m-long",
+         dim=768,
+         description=(
+             "Text embeddings, Unimodal (text), English, 2048 input tokens truncation, "
+             "Prefixes for queries/documents: necessary, 2024 year."
+         ),
+         license="apache-2.0",
+         size_in_GB=0.54,
+         sources=ModelSource(hf="snowflake/snowflake-arctic-embed-m-long"),
+         model_file="onnx/model.onnx",
+     ),
+     DenseModelDescription(
+         model="snowflake/snowflake-arctic-embed-l",
+         dim=1024,
+         description=(
+             "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
+             "Prefixes for queries/documents: necessary, 2024 year."
+         ),
+         license="apache-2.0",
+         size_in_GB=1.02,
+         sources=ModelSource(hf="snowflake/snowflake-arctic-embed-l"),
+         model_file="onnx/model.onnx",
+     ),
+     DenseModelDescription(
+         model="jinaai/jina-clip-v1",
+         dim=768,
+         description=(
+             "Text embeddings, Multimodal (text&image), English, Prefixes for queries/documents: "
+             "not necessary, 2024 year"
+         ),
+         license="apache-2.0",
+         size_in_GB=0.55,
+         sources=ModelSource(hf="jinaai/jina-clip-v1"),
+         model_file="onnx/text_model.onnx",
+     ),
+ ]
+
+
+ class OnnxTextEmbedding(TextEmbeddingBase, OnnxTextModel[NumpyArray]):
+     """Implementation of the Flag Embedding model."""
+
+     @classmethod
+     def _list_supported_models(cls) -> list[DenseModelDescription]:
+         """
+         Lists the supported models.
+
+         Returns:
+             list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+         """
+         return supported_onnx_models
+
+     def __init__(
+         self,
+         model_name: str = "BAAI/bge-small-en-v1.5",
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         lazy_load: bool = False,
+         device_id: int | None = None,
+         specific_model_path: str | None = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             model_name (str): The name of the model to use.
+             cache_dir (str, optional): The path to the cache directory.
+                 Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+                 Defaults to `fastembed_cache` in the system's temp directory.
+             threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
+             providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                 Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+             cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
+                 Defaults to Device.AUTO.
+             device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
+                 workers. Should be used with `cuda` equal to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
+                 with `providers`. Defaults to None.
+             lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                 Should be set to True when using multiple GPUs and parallel encoding. Defaults to False.
+             device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
+             specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else.
+
+         Raises:
+             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+         """
+         super().__init__(model_name, cache_dir, threads, **kwargs)
+         self.providers = providers
+         self.lazy_load = lazy_load
+         self._extra_session_options = self._select_exposed_session_options(kwargs)
+         # List of device ids that can be used for data parallel processing in workers
+         self.device_ids = device_ids
+         self.cuda = cuda
+
+         # This device_id will be used if we need to load the model in the current process
+         self.device_id: int | None = None
+         if device_id is not None:
+             self.device_id = device_id
+         elif self.device_ids is not None:
+             self.device_id = self.device_ids[0]
+
+         self.model_description = self._get_model_description(model_name)
+         self.cache_dir = str(define_cache_dir(cache_dir))
+         self._specific_model_path = specific_model_path
+         self._model_dir = self.download_model(
+             self.model_description,
+             self.cache_dir,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+         )
+
+         if not self.lazy_load:
+             self.load_onnx_model()
+
+     def embed(
+         self,
+         documents: str | Iterable[str],
+         batch_size: int = 256,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[NumpyArray]:
+         """
+         Encode a list of documents into a list of embeddings.
+         The model output is reduced to one vector per document (the CLS token for
+         token-level outputs) and normalized.
+
+         Args:
+             documents: Iterator of documents or a single document to embed
+             batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+             parallel:
+                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                 If 0, use all available cores.
+                 If None, don't use data-parallel processing, use default onnxruntime threading instead.
+
+         Returns:
+             List of embeddings, one per document
+         """
+         yield from self._embed_documents(
+             model_name=self.model_name,
+             cache_dir=str(self.cache_dir),
+             documents=documents,
+             batch_size=batch_size,
+             parallel=parallel,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_ids=self.device_ids,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+             extra_session_options=self._extra_session_options,
+             **kwargs,
+         )
+
+     @classmethod
+     def _get_worker_class(cls) -> Type["TextEmbeddingWorker[NumpyArray]"]:
+         return OnnxTextEmbeddingWorker
+
+     def _preprocess_onnx_input(
+         self, onnx_input: dict[str, NumpyArray], **kwargs: Any
+     ) -> dict[str, NumpyArray]:
+         """
+         Preprocess the onnx input.
+         """
+         return onnx_input
+
+     def _post_process_onnx_output(
+         self, output: OnnxOutputContext, **kwargs: Any
+     ) -> Iterable[NumpyArray]:
+         embeddings = output.model_output
+
+         if embeddings.ndim == 3:  # (batch_size, seq_len, embedding_dim)
+             processed_embeddings = embeddings[:, 0]
+         elif embeddings.ndim == 2:  # (batch_size, embedding_dim)
+             processed_embeddings = embeddings
+         else:
+             raise ValueError(f"Unsupported embedding shape: {embeddings.shape}")
+         return normalize(processed_embeddings)
+
+     def load_onnx_model(self) -> None:
+         self._load_onnx_model(
+             model_dir=self._model_dir,
+             model_file=self.model_description.model_file,
+             threads=self.threads,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_id=self.device_id,
+             extra_session_options=self._extra_session_options,
+         )
+
+     def token_count(
+         self, texts: str | Iterable[str], batch_size: int = 1024, **kwargs: Any
+     ) -> int:
+         return self._token_count(texts, batch_size=batch_size, **kwargs)
+
+
+ class OnnxTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
+     def init_embedding(
+         self,
+         model_name: str,
+         cache_dir: str,
+         **kwargs: Any,
+     ) -> OnnxTextEmbedding:
+         return OnnxTextEmbedding(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=1,
+             **kwargs,
+         )
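
The class above is the generic dense-text entry point: embed() streams batches through _embed_documents(), and _post_process_onnx_output() reduces the model output to one normalized vector per document. A minimal usage sketch, assuming the wheel re-exports the upstream fastembed module layout shown in the file list; the document strings are illustrative:

    from fastembed.text.onnx_embedding import OnnxTextEmbedding

    # Default model per the __init__ signature above; downloads to the
    # fastembed cache (FASTEMBED_CACHE_PATH) on first use.
    model = OnnxTextEmbedding(model_name="BAAI/bge-small-en-v1.5")

    docs = ["fastembed is light and fast", "it embeds text with ONNX models"]
    # embed() is a generator: one normalized vector per document.
    for vector in model.embed(docs, batch_size=256):
        print(vector.shape)  # (384,) -- the dim declared for bge-small-en-v1.5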
fastembed/text/onnx_text_model.py
@@ -0,0 +1,180 @@
+ import os
+ from multiprocessing import get_all_start_methods
+ from pathlib import Path
+ from typing import Any, Iterable, Sequence, Type
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from tokenizers import Encoding, Tokenizer
+
+ from fastembed.common.types import NumpyArray, OnnxProvider, Device
+ from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
+ from fastembed.common.preprocessor_utils import load_tokenizer
+ from fastembed.common.utils import iter_batch
+ from fastembed.parallel_processor import ParallelWorkerPool
+
+
+ class OnnxTextModel(OnnxModel[T]):
+     ONNX_OUTPUT_NAMES: list[str] | None = None
+
+     @classmethod
+     def _get_worker_class(cls) -> Type["TextEmbeddingWorker[T]"]:
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def _post_process_onnx_output(self, output: OnnxOutputContext, **kwargs: Any) -> Iterable[T]:
+         """Post-process the ONNX model output to convert it into a usable format.
+
+         Args:
+             output (OnnxOutputContext): The raw output from the ONNX model.
+             **kwargs: Additional keyword arguments that may be needed by specific implementations.
+
+         Returns:
+             Iterable[T]: Post-processed output as an iterable of type T.
+         """
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.tokenizer: Tokenizer | None = None
+         self.special_token_to_id: dict[str, int] = {}
+
+     def _preprocess_onnx_input(
+         self, onnx_input: dict[str, NumpyArray], **kwargs: Any
+     ) -> dict[str, NumpyArray | NDArray[np.int64]]:
+         """
+         Preprocess the onnx input.
+         """
+         return onnx_input
+
+     def _load_onnx_model(
+         self,
+         model_dir: Path,
+         model_file: str,
+         threads: int | None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_id: int | None = None,
+         extra_session_options: dict[str, Any] | None = None,
+     ) -> None:
+         super()._load_onnx_model(
+             model_dir=model_dir,
+             model_file=model_file,
+             threads=threads,
+             providers=providers,
+             cuda=cuda,
+             device_id=device_id,
+             extra_session_options=extra_session_options,
+         )
+         self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir)
+
+     def load_onnx_model(self) -> None:
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
+         return self.tokenizer.encode_batch(documents)  # type: ignore[union-attr]
+
+     def onnx_embed(
+         self,
+         documents: list[str],
+         **kwargs: Any,
+     ) -> OnnxOutputContext:
+         encoded = self.tokenize(documents, **kwargs)
+         input_ids = np.array([e.ids for e in encoded])
+         attention_mask = np.array([e.attention_mask for e in encoded])
+         input_names = {node.name for node in self.model.get_inputs()}  # type: ignore[union-attr]
+         onnx_input: dict[str, NumpyArray] = {
+             "input_ids": np.array(input_ids, dtype=np.int64),
+         }
+         if "attention_mask" in input_names:
+             onnx_input["attention_mask"] = np.array(attention_mask, dtype=np.int64)
+         if "token_type_ids" in input_names:
+             onnx_input["token_type_ids"] = np.array(
+                 [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64
+             )
+         onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs)
+
+         model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)  # type: ignore[union-attr]
+         return OnnxOutputContext(
+             model_output=model_output[0],
+             attention_mask=onnx_input.get("attention_mask", attention_mask),
+             input_ids=onnx_input.get("input_ids", input_ids),
+         )
+
+     def _embed_documents(
+         self,
+         model_name: str,
+         cache_dir: str,
+         documents: str | Iterable[str],
+         batch_size: int = 256,
+         parallel: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         local_files_only: bool = False,
+         specific_model_path: str | None = None,
+         extra_session_options: dict[str, Any] | None = None,
+         **kwargs: Any,
+     ) -> Iterable[T]:
+         is_small = False
+
+         if isinstance(documents, str):
+             documents = [documents]
+             is_small = True
+
+         if isinstance(documents, list):
+             if len(documents) < batch_size:
+                 is_small = True
+
+         if parallel is None or is_small:
+             if not hasattr(self, "model") or self.model is None:
+                 self.load_onnx_model()
+             for batch in iter_batch(documents, batch_size):
+                 yield from self._post_process_onnx_output(
+                     self.onnx_embed(batch, **kwargs), **kwargs
+                 )
+         else:
+             if parallel == 0:
+                 parallel = os.cpu_count()
+
+             start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
+             params = {
+                 "model_name": model_name,
+                 "cache_dir": cache_dir,
+                 "providers": providers,
+                 "local_files_only": local_files_only,
+                 "specific_model_path": specific_model_path,
+                 **kwargs,
+             }
+
+             if extra_session_options is not None:
+                 params.update(extra_session_options)
+
+             pool = ParallelWorkerPool(
+                 num_workers=parallel or 1,
+                 worker=self._get_worker_class(),
+                 cuda=cuda,
+                 device_ids=device_ids,
+                 start_method=start_method,
+             )
+             for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
+                 yield from self._post_process_onnx_output(batch, **kwargs)  # type: ignore
+
+     def _token_count(self, texts: str | Iterable[str], batch_size: int = 1024, **_: Any) -> int:
+         if not hasattr(self, "model") or self.model is None:
+             self.load_onnx_model()  # loads the tokenizer as well
+
+         token_num = 0
+         assert self.tokenizer is not None
+         texts = [texts] if isinstance(texts, str) else texts
+         for batch in iter_batch(texts, batch_size):
+             for tokens in self.tokenizer.encode_batch(batch):
+                 token_num += sum(tokens.attention_mask)
+
+         return token_num
+
+
+ class TextEmbeddingWorker(EmbeddingWorker[T]):
+     def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]:
+         for idx, batch in items:
+             onnx_output = self.model.onnx_embed(batch)
+             yield idx, onnx_output
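
_embed_documents() picks between two paths: small or unparallelized inputs run in-process, while a non-None `parallel` spawns a ParallelWorkerPool of TextEmbeddingWorker processes (forkserver where available, otherwise spawn) whose batches are returned in the original order via ordered_map. A hedged sketch of the parallel path; the corpus and model name are chosen for illustration:

    from fastembed.text.onnx_embedding import OnnxTextEmbedding

    # lazy_load=True avoids loading an ONNX session in the parent process when
    # the actual encoding happens in worker processes.
    model = OnnxTextEmbedding(model_name="BAAI/bge-small-en-v1.5", lazy_load=True)

    corpus = (f"document number {i}" for i in range(10_000))
    # parallel=0 expands to os.cpu_count(); each worker is created with
    # threads=1 (see init_embedding above), so workers do not oversubscribe cores.
    vectors = list(model.embed(corpus, batch_size=256, parallel=0))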
fastembed/text/pooled_embedding.py
@@ -0,0 +1,136 @@
+ from typing import Any, Iterable, Type
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from fastembed.common.types import NumpyArray
+ from fastembed.common.onnx_model import OnnxOutputContext
+ from fastembed.common.utils import mean_pooling
+ from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
+ from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+ supported_pooled_models: list[DenseModelDescription] = [
+     DenseModelDescription(
+         model="nomic-ai/nomic-embed-text-v1.5",
+         dim=768,
+         description=(
+             "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, "
+             "Prefixes for queries/documents: necessary, 2024 year."
+         ),
+         license="apache-2.0",
+         size_in_GB=0.52,
+         sources=ModelSource(hf="nomic-ai/nomic-embed-text-v1.5"),
+         model_file="onnx/model.onnx",
+     ),
+     DenseModelDescription(
+         model="nomic-ai/nomic-embed-text-v1.5-Q",
+         dim=768,
+         description=(
+             "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, "
+             "Prefixes for queries/documents: necessary, 2024 year."
+         ),
+         license="apache-2.0",
+         size_in_GB=0.13,
+         sources=ModelSource(hf="nomic-ai/nomic-embed-text-v1.5"),
+         model_file="onnx/model_quantized.onnx",
+     ),
+     DenseModelDescription(
+         model="nomic-ai/nomic-embed-text-v1",
+         dim=768,
+         description=(
+             "Text embeddings, Multimodal (text, image), English, 8192 input tokens truncation, "
+             "Prefixes for queries/documents: necessary, 2024 year."
+         ),
+         license="apache-2.0",
+         size_in_GB=0.52,
+         sources=ModelSource(hf="nomic-ai/nomic-embed-text-v1"),
+         model_file="onnx/model.onnx",
+     ),
+     DenseModelDescription(
+         model="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+         dim=384,
+         description=(
+             "Text embeddings, Unimodal (text), Multilingual (~50 languages), 512 input tokens truncation, "
+             "Prefixes for queries/documents: not necessary, 2019 year."
+         ),
+         license="apache-2.0",
+         size_in_GB=0.22,
+         sources=ModelSource(hf="qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q"),
+         model_file="model_optimized.onnx",
+     ),
+     DenseModelDescription(
+         model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+         dim=768,
+         description=(
+             "Text embeddings, Unimodal (text), Multilingual (~50 languages), 384 input tokens truncation, "
+             "Prefixes for queries/documents: not necessary, 2021 year."
+         ),
+         license="apache-2.0",
+         size_in_GB=1.00,
+         sources=ModelSource(hf="xenova/paraphrase-multilingual-mpnet-base-v2"),
+         model_file="onnx/model.onnx",
+     ),
+     DenseModelDescription(
+         model="intfloat/multilingual-e5-large",
+         dim=1024,
+         description=(
+             "Text embeddings, Unimodal (text), Multilingual (~100 languages), 512 input tokens truncation, "
+             "Prefixes for queries/documents: necessary, 2024 year."
+         ),
+         license="mit",
+         size_in_GB=2.24,
+         sources=ModelSource(
+             hf="qdrant/multilingual-e5-large-onnx",
+             url="https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz",
+             _deprecated_tar_struct=True,
+         ),
+         model_file="model.onnx",
+         additional_files=["model.onnx_data"],
+     ),
+ ]
+
+
+ class PooledEmbedding(OnnxTextEmbedding):
+     @classmethod
+     def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
+         return PooledEmbeddingWorker
+
+     @classmethod
+     def mean_pooling(
+         cls, model_output: NumpyArray, attention_mask: NDArray[np.int64]
+     ) -> NumpyArray:
+         return mean_pooling(model_output, attention_mask)
+
+     @classmethod
+     def _list_supported_models(cls) -> list[DenseModelDescription]:
+         """Lists the supported models.
+
+         Returns:
+             list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+         """
+         return supported_pooled_models
+
+     def _post_process_onnx_output(
+         self, output: OnnxOutputContext, **kwargs: Any
+     ) -> Iterable[NumpyArray]:
+         if output.attention_mask is None:
+             raise ValueError("attention_mask must be provided for document post-processing")
+
+         embeddings = output.model_output
+         attn_mask = output.attention_mask
+         return self.mean_pooling(embeddings, attn_mask)
+
+
+ class PooledEmbeddingWorker(OnnxTextEmbeddingWorker):
+     def init_embedding(
+         self,
+         model_name: str,
+         cache_dir: str,
+         **kwargs: Any,
+     ) -> OnnxTextEmbedding:
+         return PooledEmbedding(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=1,
+             **kwargs,
+         )
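
PooledEmbedding differs from the base class only in post-processing: instead of taking the CLS token, it mean-pools the token embeddings under the attention mask via mean_pooling from fastembed.common.utils, which is not part of this diff. The sketch below is an equivalent reference implementation of that operation, written here as an assumption about its behavior rather than the packaged code:

    import numpy as np

    def mean_pooling(model_output: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
        # model_output: (batch, seq_len, dim); attention_mask: (batch, seq_len)
        mask = np.expand_dims(attention_mask, axis=-1).astype(model_output.dtype)
        summed = (model_output * mask).sum(axis=1)      # sum embeddings of real tokens
        counts = np.clip(mask.sum(axis=1), 1e-9, None)  # per-row token counts, avoid /0
        return summed / counts

Masking before averaging means padding tokens contribute nothing to the document vector, which is why _post_process_onnx_output above refuses to run without an attention mask.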