fastembed_bio-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
@@ -0,0 +1,372 @@
+ from pathlib import Path
+
+ from typing import Any, Sequence, Iterable, Type
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from py_rust_stemmers import SnowballStemmer
+ from tokenizers import Tokenizer
+
+ from fastembed.common.model_description import SparseModelDescription, ModelSource
+ from fastembed.common.onnx_model import OnnxOutputContext
+ from fastembed.common import OnnxProvider
+ from fastembed.common.types import Device
+ from fastembed.common.utils import define_cache_dir
+ from fastembed.sparse.sparse_embedding_base import (
+     SparseEmbedding,
+     SparseTextEmbeddingBase,
+ )
+ from fastembed.sparse.utils.minicoil_encoder import Encoder
+ from fastembed.sparse.utils.sparse_vectors_converter import SparseVectorConverter, WordEmbedding
+ from fastembed.sparse.utils.vocab_resolver import VocabResolver, VocabTokenizer
+ from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker
+
+
+ MINICOIL_MODEL_FILE = "minicoil.triplet.model.npy"
+ MINICOIL_VOCAB_FILE = "minicoil.triplet.model.vocab"
+ STOPWORDS_FILE = "stopwords.txt"
+
+
+ supported_minicoil_models: list[SparseModelDescription] = [
+     SparseModelDescription(
+         model="Qdrant/minicoil-v1",
+         vocab_size=19125,
+         description="Sparse embedding model that resolves the semantic meaning of words "
+         "while keeping exact keyword-match behavior. "
+         "Based on jinaai/jina-embeddings-v2-small-en-tokens",
+         license="apache-2.0",
+         size_in_GB=0.09,
+         sources=ModelSource(hf="Qdrant/minicoil-v1"),
+         model_file="onnx/model.onnx",
+         additional_files=[
+             STOPWORDS_FILE,
+             MINICOIL_MODEL_FILE,
+             MINICOIL_VOCAB_FILE,
+         ],
+         requires_idf=True,
+     ),
+ ]
+
+ _MODEL_TO_LANGUAGE = {
+     "Qdrant/minicoil-v1": "english",
+ }
+ MODEL_TO_LANGUAGE = {
+     model_name.lower(): language for model_name, language in _MODEL_TO_LANGUAGE.items()
+ }
+
+
+ def get_language_by_model_name(model_name: str) -> str:
+     return MODEL_TO_LANGUAGE[model_name.lower()]
+
+
+ class MiniCOIL(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]):
+     """
+     MiniCOIL is a sparse embedding model that resolves the semantic meaning of words
+     while keeping exact keyword-match behavior.
+
+     Each vocabulary word is converted into a 4d component of a sparse vector, which is then
+     weighted by the term frequency in the corpus. Out-of-vocabulary words are scored exactly as in BM25.
+
+     The model is based on `jinaai/jina-embeddings-v2-small-en-tokens`.
+     """
+
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         k: float = 1.2,
+         b: float = 0.75,
+         avg_len: float = 150.0,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         lazy_load: bool = False,
+         device_id: int | None = None,
+         specific_model_path: str | None = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             model_name (str): The name of the model to use.
+             cache_dir (str, optional): The path to the cache directory.
+                 Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+                 Defaults to `fastembed_cache` in the system's temp directory.
+             threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
+             providers (Optional[Sequence[OnnxProvider]], optional): The providers to use for onnxruntime.
+             k (float, optional): The k parameter in the BM25 formula. Controls term-frequency saturation,
+                 i.e. how quickly additional occurrences of a term stop increasing the score. Defaults to 1.2.
+             b (float, optional): The b parameter in the BM25 formula. Defines the importance of the document length.
+                 Defaults to 0.75.
+             avg_len (float, optional): The average length of the documents in the corpus. Defaults to 150.0.
+             cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
+                 Defaults to Device.AUTO.
+             device_ids (Optional[list[int]], optional): The list of device ids to use for data-parallel processing in
+                 workers. Should be used with `cuda` set to `True`, `Device.AUTO` or `Device.CUDA`; mutually exclusive
+                 with `providers`. Defaults to None.
+             lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                 Should be set to True when using multiple GPUs and parallel encoding. Defaults to False.
+             device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
+             specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be loaded from elsewhere.
+
+         Raises:
+             ValueError: If the model_name is not in the format <org>/<model>, e.g. BAAI/bge-base-en.
+         """
+
+         super().__init__(model_name, cache_dir, threads, **kwargs)
+         self.providers = providers
+         self.lazy_load = lazy_load
+         self.device_ids = device_ids
+         self.cuda = cuda
+         self.device_id = device_id
+         self._extra_session_options = self._select_exposed_session_options(kwargs)
+
+         self.k = k
+         self.b = b
+         self.avg_len = avg_len
+
+         # Initialize class attributes
+         self.tokenizer: Tokenizer | None = None
+         self.invert_vocab: dict[int, str] = {}
+         self.special_tokens: set[str] = set()
+         self.special_tokens_ids: set[int] = set()
+         self.stopwords: set[str] = set()
+         self.vocab_resolver: VocabResolver | None = None
+         self.encoder: Encoder | None = None
+         self.output_dim: int | None = None
+         self.sparse_vector_converter: SparseVectorConverter | None = None
+
+         self.model_description = self._get_model_description(model_name)
+         self.cache_dir = str(define_cache_dir(cache_dir))
+         self._specific_model_path = specific_model_path
+         self._model_dir = self.download_model(
+             self.model_description,
+             self.cache_dir,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+         )
+
+         if not self.lazy_load:
+             self.load_onnx_model()
+
+     def load_onnx_model(self) -> None:
+         self._load_onnx_model(
+             model_dir=self._model_dir,
+             model_file=self.model_description.model_file,
+             threads=self.threads,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_id=self.device_id,
+             extra_session_options=self._extra_session_options,
+         )
+
+         assert self.tokenizer is not None
+
+         for token, idx in self.tokenizer.get_vocab().items():  # type: ignore[union-attr]
+             self.invert_vocab[idx] = token
+         self.special_tokens = set(self.special_token_to_id.keys())
+         self.special_tokens_ids = set(self.special_token_to_id.values())
+         self.stopwords = set(self._load_stopwords(self._model_dir))
+
+         stemmer = SnowballStemmer(get_language_by_model_name(self.model_name))
+
+         self.vocab_resolver = VocabResolver(
+             tokenizer=VocabTokenizer(self.tokenizer),
+             stopwords=self.stopwords,
+             stemmer=stemmer,
+         )
+         self.vocab_resolver.load_json_vocab(str(self._model_dir / MINICOIL_VOCAB_FILE))
+
+         weights = np.load(str(self._model_dir / MINICOIL_MODEL_FILE), mmap_mode="r")
+         self.encoder = Encoder(weights)
+         self.output_dim = self.encoder.output_dim
+
+         self.sparse_vector_converter = SparseVectorConverter(
+             stopwords=self.stopwords,
+             stemmer=stemmer,
+             k=self.k,
+             b=self.b,
+             avg_len=self.avg_len,
+         )
+
+     def token_count(
+         self, texts: str | Iterable[str], batch_size: int = 1024, **kwargs: Any
+     ) -> int:
+         return self._token_count(texts, batch_size=batch_size, **kwargs)
+
+     def embed(
+         self,
+         documents: str | Iterable[str],
+         batch_size: int = 256,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[SparseEmbedding]:
+         """
+         Encode a list of documents into a list of embeddings.
+         We use mean pooling with attention so that the model can handle variable-length inputs.
+
+         Args:
+             documents: Iterator of documents or a single document to embed
+             batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+             parallel:
+                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                 If 0, use all available cores.
+                 If None, don't use data-parallel processing, use default onnxruntime threading instead.
+
+         Returns:
+             List of embeddings, one per document
+         """
+         yield from self._embed_documents(
+             model_name=self.model_name,
+             cache_dir=str(self.cache_dir),
+             documents=documents,
+             batch_size=batch_size,
+             parallel=parallel,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_ids=self.device_ids,
+             k=self.k,
+             b=self.b,
+             avg_len=self.avg_len,
+             is_query=False,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+             extra_session_options=self._extra_session_options,
+             **kwargs,
+         )
+
+     def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]:
+         """
+         Encode a list of queries into a list of embeddings.
+         """
+         yield from self._embed_documents(
+             model_name=self.model_name,
+             cache_dir=str(self.cache_dir),
+             documents=query,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_ids=self.device_ids,
+             k=self.k,
+             b=self.b,
+             avg_len=self.avg_len,
+             is_query=True,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+             **kwargs,
+         )
+
+     @classmethod
+     def _load_stopwords(cls, model_dir: Path) -> list[str]:
+         stopwords_path = model_dir / STOPWORDS_FILE
+         if not stopwords_path.exists():
+             return []
+
+         with open(stopwords_path, "r") as f:
+             return f.read().splitlines()
+
+     @classmethod
+     def _list_supported_models(cls) -> list[SparseModelDescription]:
+         """Lists the supported models.
+
+         Returns:
+             list[SparseModelDescription]: A list of SparseModelDescription objects containing the model information.
+         """
+         return supported_minicoil_models
+
+     def _post_process_onnx_output(
+         self, output: OnnxOutputContext, is_query: bool = False, **kwargs: Any
+     ) -> Iterable[SparseEmbedding]:
+         if output.input_ids is None:
+             raise ValueError("input_ids must be provided for document post-processing")
+
+         assert self.vocab_resolver is not None
+         assert self.encoder is not None
+         assert self.sparse_vector_converter is not None
+
+         # Size: (batch_size, sequence_length, hidden_size)
+         embeddings = output.model_output
+         # Size: (batch_size, sequence_length)
+         assert output.attention_mask is not None
+         masks = output.attention_mask
+
+         vocab_size = self.vocab_resolver.vocab_size()
+         embedding_size = self.encoder.output_dim
+
+         # For each document we only select those embeddings that are not masked out
+
+         for i in range(embeddings.shape[0]):
+             # Size: (sequence_length, hidden_size)
+             token_embeddings = embeddings[i, masks[i] == 1]
+
+             # Size: (sequence_length)
+             token_ids: NDArray[np.int64] = output.input_ids[i, masks[i] == 1]
+
+             word_ids_array, counts, oov, forms = self.vocab_resolver.resolve_tokens(token_ids)
+
+             # Size: (1, words)
+             word_ids_array_expanded: NDArray[np.int64] = np.expand_dims(word_ids_array, axis=0)
+
+             # Size: (1, words, embedding_size)
+             token_embeddings_array: NDArray[np.float32] = np.expand_dims(token_embeddings, axis=0)
+
+             assert word_ids_array_expanded.shape[1] == token_embeddings_array.shape[1]
+
+             # Size of word_ids_mapping: (unique_words, 2) - [vocab_id, batch_id]
+             # Size of embeddings: (unique_words, embedding_size)
+             ids_mapping, minicoil_embeddings = self.encoder.forward(
+                 word_ids_array_expanded, token_embeddings_array
+             )
+
+             # Size of counts: (unique_words)
+             words_ids: list[int] = ids_mapping[:, 0].tolist()  # type: ignore[assignment]
+
+             sentence_result: dict[str, WordEmbedding] = {}
+
+             words = [self.vocab_resolver.lookup_word(word_id) for word_id in words_ids]
+
+             for word, word_id, emb in zip(words, words_ids, minicoil_embeddings.tolist()):  # type: ignore[arg-type]
+                 if word_id == 0:
+                     continue
+
+                 sentence_result[word] = WordEmbedding(
+                     word=word,
+                     forms=forms[word],
+                     count=int(counts[word_id]),
+                     word_id=int(word_id),
+                     embedding=emb,  # type: ignore[arg-type]
+                 )
+
+             for oov_word, count in oov.items():
+                 # {
+                 #     "word": oov_word,
+                 #     "forms": [oov_word],
+                 #     "count": int(count),
+                 #     "word_id": -1,
+                 #     "embedding": [1]
+                 # }
+                 sentence_result[oov_word] = WordEmbedding(
+                     word=oov_word, forms=[oov_word], count=int(count), word_id=-1, embedding=[1]
+                 )
+
+             if not is_query:
+                 yield self.sparse_vector_converter.embedding_to_vector(
+                     sentence_result, vocab_size=vocab_size, embedding_size=embedding_size
+                 )
+             else:
+                 yield self.sparse_vector_converter.embedding_to_vector_query(
+                     sentence_result, vocab_size=vocab_size, embedding_size=embedding_size
+                 )
+
+     @classmethod
+     def _get_worker_class(cls) -> Type["MiniCoilTextEmbeddingWorker"]:
+         return MiniCoilTextEmbeddingWorker
+
+
+ class MiniCoilTextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
+     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> MiniCOIL:
+         return MiniCOIL(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=1,
+             **kwargs,
+         )
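
For orientation, a minimal usage sketch of the module above. This is an illustration under assumptions (the package is installed and the `Qdrant/minicoil-v1` files can be fetched on first use), not an excerpt from the package's own docs:

from fastembed.sparse.minicoil import MiniCOIL

# k, b and avg_len parameterize the standard BM25 term-frequency weight,
#   weight(tf) = tf * (k + 1) / (tf + k * (1 - b + b * doc_len / avg_len)),
# which scales each word's 4d component (the exact weighting lives in
# fastembed/sparse/utils/sparse_vectors_converter.py).
model = MiniCOIL("Qdrant/minicoil-v1", k=1.2, b=0.75, avg_len=150.0)

documents = [
    "A bat flew out of the cave.",
    "He swung the bat and hit a home run.",
]

# embed() is a generator yielding one SparseEmbedding per document.
for embedding in model.embed(documents, batch_size=32):
    print(embedding.indices, embedding.values)

# Queries go through the query-specific converter (is_query=True above).
query_embedding = next(model.query_embed("baseball bat"))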
@@ -0,0 +1,90 @@
+ from dataclasses import dataclass
+ from typing import Iterable, Any
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from fastembed.common.model_description import SparseModelDescription
+ from fastembed.common.types import NumpyArray
+ from fastembed.common.model_management import ModelManagement
+
+
+ @dataclass
+ class SparseEmbedding:
+     values: NumpyArray
+     indices: NDArray[np.int64] | NDArray[np.int32]
+
+     def as_object(self) -> dict[str, NumpyArray]:
+         return {
+             "values": self.values,
+             "indices": self.indices,
+         }
+
+     def as_dict(self) -> dict[int, float]:
+         return {int(i): float(v) for i, v in zip(self.indices, self.values)}  # type: ignore
+
+     @classmethod
+     def from_dict(cls, data: dict[int, float]) -> "SparseEmbedding":
+         if len(data) == 0:
+             return cls(values=np.array([]), indices=np.array([]))
+         indices, values = zip(*data.items())
+         return cls(values=np.array(values), indices=np.array(indices))
+
+
+ class SparseTextEmbeddingBase(ModelManagement[SparseModelDescription]):
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         **kwargs: Any,
+     ):
+         self.model_name = model_name
+         self.cache_dir = cache_dir
+         self.threads = threads
+         self._local_files_only = kwargs.pop("local_files_only", False)
+
+     def embed(
+         self,
+         documents: str | Iterable[str],
+         batch_size: int = 256,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[SparseEmbedding]:
+         raise NotImplementedError()
+
+     def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]:
+         """
+         Embeds a list of text passages into a list of embeddings.
+
+         Args:
+             texts (Iterable[str]): The list of texts to embed.
+             **kwargs: Additional keyword arguments to pass to the embed method.
+
+         Yields:
+             Iterable[SparseEmbedding]: The sparse embeddings.
+         """
+
+         # This is model-specific, so that different models can have specialized implementations
+         yield from self.embed(texts, **kwargs)
+
+     def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]:
+         """
+         Embeds queries.
+
+         Args:
+             query (Union[str, Iterable[str]]): The query to embed, or an iterable, e.g. a list of queries.
+
+         Returns:
+             Iterable[SparseEmbedding]: The sparse embeddings.
+         """
+
+         # This is model-specific, so that different models can have specialized implementations
+         if isinstance(query, str):
+             yield from self.embed([query], **kwargs)
+         else:
+             yield from self.embed(query, **kwargs)
+
+     def token_count(self, texts: str | Iterable[str], **kwargs: Any) -> int:
+         """Returns the number of tokens in the texts."""
+         raise NotImplementedError("Subclasses must implement this method")
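
A self-contained sketch of the `SparseEmbedding` container defined above, showing the dict round-trip (the index/value pairs are made up):

import numpy as np

from fastembed.sparse.sparse_embedding_base import SparseEmbedding

# Build a sparse vector from {index: value} pairs.
embedding = SparseEmbedding.from_dict({17: 0.8, 4096: 0.25})
print(embedding.indices)  # e.g. [  17 4096]
print(embedding.values)   # e.g. [0.8  0.25]

# as_dict() inverts from_dict(), casting back to plain Python int/float.
assert embedding.as_dict() == {17: 0.8, 4096: 0.25}

# as_object() exposes the underlying numpy arrays, e.g. for serialization.
arrays = embedding.as_object()
assert isinstance(arrays["values"], np.ndarray)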
@@ -0,0 +1,143 @@
+ import warnings
+ from dataclasses import asdict
+ from typing import Any, Iterable, Sequence, Type
+
+ from fastembed.common import OnnxProvider
+ from fastembed.common.model_description import SparseModelDescription
+ from fastembed.common.types import Device
+ from fastembed.sparse.bm25 import Bm25
+ from fastembed.sparse.bm42 import Bm42
+ from fastembed.sparse.minicoil import MiniCOIL
+ from fastembed.sparse.sparse_embedding_base import (
+     SparseEmbedding,
+     SparseTextEmbeddingBase,
+ )
+ from fastembed.sparse.splade_pp import SpladePP
+
+
+ class SparseTextEmbedding(SparseTextEmbeddingBase):
+     EMBEDDINGS_REGISTRY: list[Type[SparseTextEmbeddingBase]] = [SpladePP, Bm42, Bm25, MiniCOIL]
+
+     @classmethod
+     def list_supported_models(cls) -> list[dict[str, Any]]:
+         """
+         Lists the supported models.
+
+         Returns:
+             list[dict[str, Any]]: A list of dictionaries containing the model information.
+
+         Example:
+             ```
+             [
+                 {
+                     "model": "prithivida/Splade_PP_en_v1",
+                     "vocab_size": 30522,
+                     "description": "Independent Implementation of SPLADE++ Model for English",
+                     "license": "apache-2.0",
+                     "size_in_GB": 0.532,
+                     "sources": {
+                         "hf": "qdrant/SPLADE_PP_en_v1",
+                     },
+                 }
+             ]
+             ```
+         """
+         return [asdict(model) for model in cls._list_supported_models()]
+
+     @classmethod
+     def _list_supported_models(cls) -> list[SparseModelDescription]:
+         result: list[SparseModelDescription] = []
+         for embedding in cls.EMBEDDINGS_REGISTRY:
+             result.extend(embedding._list_supported_models())
+         return result
+
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         lazy_load: bool = False,
+         **kwargs: Any,
+     ):
+         super().__init__(model_name, cache_dir, threads, **kwargs)
+         if model_name.lower() == "prithvida/Splade_PP_en_v1".lower():
+             warnings.warn(
+                 "The correct spelling is prithivida/Splade_PP_en_v1. "
+                 "Support for this name will be removed soon; please fix the model_name",
+                 DeprecationWarning,
+                 stacklevel=2,
+             )
+             model_name = "prithivida/Splade_PP_en_v1"
+
+         for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
+             supported_models = EMBEDDING_MODEL_TYPE._list_supported_models()
+             if any(model_name.lower() == model.model.lower() for model in supported_models):
+                 self.model = EMBEDDING_MODEL_TYPE(
+                     model_name,
+                     cache_dir,
+                     threads=threads,
+                     providers=providers,
+                     cuda=cuda,
+                     device_ids=device_ids,
+                     lazy_load=lazy_load,
+                     **kwargs,
+                 )
+                 return
+
+         raise ValueError(
+             f"Model {model_name} is not supported in SparseTextEmbedding. "
+             "Please check the supported models using `SparseTextEmbedding.list_supported_models()`"
+         )
+
+     def embed(
+         self,
+         documents: str | Iterable[str],
+         batch_size: int = 256,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[SparseEmbedding]:
+         """
+         Encode a list of documents into a list of embeddings.
+         We use mean pooling with attention so that the model can handle variable-length inputs.
+
+         Args:
+             documents: Iterator of documents or a single document to embed
+             batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+             parallel:
+                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                 If 0, use all available cores.
+                 If None, don't use data-parallel processing, use default onnxruntime threading instead.
+
+         Returns:
+             List of embeddings, one per document
+         """
+         yield from self.model.embed(documents, batch_size, parallel, **kwargs)
+
+     def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]:
+         """
+         Embeds queries.
+
+         Args:
+             query (Union[str, Iterable[str]]): The query to embed, or an iterable, e.g. a list of queries.
+
+         Returns:
+             Iterable[SparseEmbedding]: The sparse embeddings.
+         """
+         yield from self.model.query_embed(query, **kwargs)
+
+     def token_count(
+         self, texts: str | Iterable[str], batch_size: int = 1024, **kwargs: Any
+     ) -> int:
+         """Returns the total number of tokens in the texts.
+
+         Args:
+             texts (str | Iterable[str]): The texts to count tokens for.
+             batch_size (int): Batch size for encoding
+
+         Returns:
+             int: Sum of the number of tokens in the texts.
+         """
+         return self.model.token_count(texts, batch_size=batch_size, **kwargs)
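
Finally, a minimal sketch of how the dispatching wrapper above ties the registry together (model name taken from the MiniCOIL description earlier in this diff; weights are downloaded on first use):

from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding

# The registry aggregates SpladePP, Bm42, Bm25 and MiniCOIL descriptions.
for description in SparseTextEmbedding.list_supported_models():
    print(description["model"], description["size_in_GB"])

# The constructor matches the name case-insensitively against the registry
# and instantiates the winning implementation behind self.model.
model = SparseTextEmbedding("Qdrant/minicoil-v1")

sparse_vectors = list(model.embed(["hello world"], batch_size=32))
total_tokens = model.token_count(["hello world"])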