fastembed_bio-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/late_interaction/colbert.py
@@ -0,0 +1,301 @@
+ import string
+ from typing import Any, Iterable, Sequence, Type
+
+ import numpy as np
+ from tokenizers import Encoding, Tokenizer
+
+ from fastembed.common.preprocessor_utils import load_tokenizer
+ from fastembed.common.types import NumpyArray, Device
+ from fastembed.common import OnnxProvider
+ from fastembed.common.onnx_model import OnnxOutputContext
+ from fastembed.common.utils import define_cache_dir, iter_batch
+ from fastembed.late_interaction.late_interaction_embedding_base import (
+     LateInteractionTextEmbeddingBase,
+ )
+ from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker
+ from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+ supported_colbert_models: list[DenseModelDescription] = [
+     DenseModelDescription(
+         model="colbert-ir/colbertv2.0",
+         dim=128,
+         description="Text embeddings, Unimodal (text), English, 512 input tokens truncation, 2023 year",
+         license="mit",
+         size_in_GB=0.44,
+         sources=ModelSource(hf="colbert-ir/colbertv2.0"),
+         model_file="model.onnx",
+     ),
+     DenseModelDescription(
+         model="answerdotai/answerai-colbert-small-v1",
+         dim=96,
+         description="Text embeddings, Unimodal (text), English, 512 input tokens truncation, 2024 year",
+         license="apache-2.0",
+         size_in_GB=0.13,
+         sources=ModelSource(hf="answerdotai/answerai-colbert-small-v1"),
+         model_file="vespa_colbert.onnx",
+     ),
+ ]
+
+
+ class Colbert(LateInteractionTextEmbeddingBase, OnnxTextModel[NumpyArray]):
+     QUERY_MARKER_TOKEN_ID = 1
+     DOCUMENT_MARKER_TOKEN_ID = 2
+     MIN_QUERY_LENGTH = 31  # effectively 32: one additional special token is added at the beginning
+     MASK_TOKEN = "[MASK]"
+
+     def _post_process_onnx_output(
+         self, output: OnnxOutputContext, is_doc: bool = True, **kwargs: Any
+     ) -> Iterable[NumpyArray]:
+         if not is_doc:
+             for embedding in output.model_output:
+                 yield embedding
+         else:
+             if output.input_ids is None or output.attention_mask is None:
+                 raise ValueError(
+                     "input_ids and attention_mask must be provided for document post-processing"
+                 )
+
+             for i, token_sequence in enumerate(output.input_ids):
+                 for j, token_id in enumerate(token_sequence):  # type: ignore
+                     if token_id in self.skip_list or token_id == self.pad_token_id:
+                         output.attention_mask[i, j] = 0
+
+             output.model_output *= np.expand_dims(output.attention_mask, 2)
+             norm = np.linalg.norm(output.model_output, ord=2, axis=2, keepdims=True)
+             norm_clamped = np.maximum(norm, 1e-12)
+             output.model_output /= norm_clamped
+
+             for embedding, attention_mask in zip(output.model_output, output.attention_mask):
+                 yield embedding[attention_mask == 1]
+
+     def _preprocess_onnx_input(
+         self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
+     ) -> dict[str, NumpyArray]:
+         marker_token = self.DOCUMENT_MARKER_TOKEN_ID if is_doc else self.QUERY_MARKER_TOKEN_ID
+         onnx_input["input_ids"] = np.insert(
+             onnx_input["input_ids"].astype(np.int64), 1, marker_token, axis=1
+         )
+         onnx_input["attention_mask"] = np.insert(
+             onnx_input["attention_mask"].astype(np.int64), 1, 1, axis=1
+         )
+         return onnx_input
+
+     def tokenize(self, documents: list[str], is_doc: bool = True, **kwargs: Any) -> list[Encoding]:
+         return (
+             self._tokenize_documents(documents=documents)
+             if is_doc
+             else self._tokenize_query(query=next(iter(documents)))
+         )
+
+     def _tokenize_query(self, query: str) -> list[Encoding]:
+         assert self.query_tokenizer is not None
+         encoded = self.query_tokenizer.encode_batch([query])
+         return encoded
+
+     def _tokenize_documents(self, documents: list[str]) -> list[Encoding]:
+         encoded = self.tokenizer.encode_batch(documents)  # type: ignore[union-attr]
+         return encoded
+
+     def token_count(
+         self,
+         texts: str | Iterable[str],
+         batch_size: int = 1024,
+         is_doc: bool = True,
+         include_extension: bool = False,
+         **kwargs: Any,
+     ) -> int:
+         if not hasattr(self, "model") or self.model is None:
+             self.load_onnx_model()  # loads the tokenizer as well
+         token_num = 0
+         texts = [texts] if isinstance(texts, str) else texts
+         tokenizer = self.tokenizer if is_doc else self.query_tokenizer
+         assert tokenizer is not None
+         for batch in iter_batch(texts, batch_size):
+             for tokens in tokenizer.encode_batch(batch):
+                 if is_doc:
+                     token_num += sum(tokens.attention_mask)
+                 else:
+                     attend_count = sum(tokens.attention_mask)
+                     if include_extension:
+                         token_num += max(attend_count, self.MIN_QUERY_LENGTH)
+                     else:
+                         token_num += attend_count
+             if include_extension:
+                 # add 1 per text for the DOCUMENT_MARKER_TOKEN_ID or QUERY_MARKER_TOKEN_ID
+                 token_num += len(batch)
+
+         return token_num
+
+     @classmethod
+     def _list_supported_models(cls) -> list[DenseModelDescription]:
+         """Lists the supported models.
+
+         Returns:
+             list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+         """
+         return supported_colbert_models
+
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         lazy_load: bool = False,
+         device_id: int | None = None,
+         specific_model_path: str | None = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             model_name (str): The name of the model to use.
+             cache_dir (str, optional): The path to the cache directory.
+                 Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+                 Defaults to `fastembed_cache` in the system's temp directory.
+             threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
+             providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                 Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+             cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
+                 Defaults to Device.AUTO.
+             device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
+                 workers. Should be used with `cuda` equal to `True`, `Device.AUTO` or `Device.CUDA`; mutually exclusive
+                 with `providers`. Defaults to None.
+             lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                 Should be set to True when using multiple GPUs with parallel encoding. Defaults to False.
+             device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
+             specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else.
+
+         Raises:
+             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+         """
+
+         super().__init__(model_name, cache_dir, threads, **kwargs)
+         self.providers = providers
+         self.lazy_load = lazy_load
+         self._extra_session_options = self._select_exposed_session_options(kwargs)
+
+         # List of device ids that can be used for data parallel processing in workers
+         self.device_ids = device_ids
+         self.cuda = cuda
+
+         # This device_id will be used if we need to load the model in the current process
+         self.device_id: int | None = None
+         if device_id is not None:
+             self.device_id = device_id
+         elif self.device_ids is not None:
+             self.device_id = self.device_ids[0]
+
+         self.model_description = self._get_model_description(model_name)
+         self.cache_dir = str(define_cache_dir(cache_dir))
+
+         self._specific_model_path = specific_model_path
+         self._model_dir = self.download_model(
+             self.model_description,
+             self.cache_dir,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+         )
+         self.mask_token_id: int | None = None
+         self.pad_token_id: int | None = None
+         self.skip_list: set[int] = set()
+
+         self.query_tokenizer: Tokenizer | None = None
+
+         if not self.lazy_load:
+             self.load_onnx_model()
+
+     def load_onnx_model(self) -> None:
+         self._load_onnx_model(
+             model_dir=self._model_dir,
+             model_file=self.model_description.model_file,
+             threads=self.threads,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_id=self.device_id,
+             extra_session_options=self._extra_session_options,
+         )
+         self.query_tokenizer, _ = load_tokenizer(model_dir=self._model_dir)
+
+         assert self.tokenizer is not None
+         self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN]
+         self.pad_token_id = self.tokenizer.padding["pad_id"]
+         self.skip_list = {
+             self.tokenizer.encode(symbol, add_special_tokens=False).ids[0]
+             for symbol in string.punctuation
+         }
+         current_max_length = self.tokenizer.truncation["max_length"]
+         # ensure we don't overflow after adding the document marker
+         self.tokenizer.enable_truncation(max_length=current_max_length - 1)
+         self.query_tokenizer.enable_truncation(max_length=current_max_length - 1)
+         self.query_tokenizer.enable_padding(
+             pad_token=self.MASK_TOKEN,
+             pad_id=self.mask_token_id,
+             length=self.MIN_QUERY_LENGTH,
+         )
+
+     def embed(
+         self,
+         documents: str | Iterable[str],
+         batch_size: int = 256,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[NumpyArray]:
+         """
+         Encode a list of documents into a list of embeddings.
+         Each document yields a matrix of token-level embeddings, one row per attended token.
+
+         Args:
+             documents: Iterator of documents or a single document to embed
+             batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+             parallel:
+                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                 If 0, use all available cores.
+                 If None, don't use data-parallel processing; use default onnxruntime threading instead.
+
+         Returns:
+             List of embeddings, one per document
+         """
+         yield from self._embed_documents(
+             model_name=self.model_name,
+             cache_dir=str(self.cache_dir),
+             documents=documents,
+             batch_size=batch_size,
+             parallel=parallel,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_ids=self.device_ids,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+             extra_session_options=self._extra_session_options,
+             **kwargs,
+         )
+
+     def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
+         if isinstance(query, str):
+             query = [query]
+
+         if not hasattr(self, "model") or self.model is None:
+             self.load_onnx_model()
+
+         for text in query:
+             yield from self._post_process_onnx_output(
+                 self.onnx_embed([text], is_doc=False), is_doc=False
+             )
+
+     @classmethod
+     def _get_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
+         return ColbertEmbeddingWorker
+
+
+ class ColbertEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
+     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> Colbert:
+         return Colbert(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=1,
+             **kwargs,
+         )
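A minimal usage sketch for the Colbert class above (not part of the packaged diff; assumes the wheel is installed and the model weights can be downloaded). Documents come back as one matrix of 128-dim token vectors each, while queries are padded with [MASK] to at least 32 tokens:

from fastembed.late_interaction.colbert import Colbert

model = Colbert("colbert-ir/colbertv2.0")

# One (num_attended_tokens, 128) matrix per document; punctuation and padding rows are dropped.
doc_embeddings = list(model.embed(["ColBERT keeps one vector per token instead of pooling."]))

# Short queries are padded with [MASK] up to 32 tokens (MIN_QUERY_LENGTH plus the query marker).
query_embeddings = list(model.query_embed("what is late interaction?"))

print(doc_embeddings[0].shape)    # e.g. (12, 128)
print(query_embeddings[0].shape)  # e.g. (32, 128)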
fastembed/late_interaction/jina_colbert.py
@@ -0,0 +1,58 @@
+ from typing import Any, Type
+
+ from fastembed.common.types import NumpyArray
+ from fastembed.late_interaction.colbert import Colbert, ColbertEmbeddingWorker
+ from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+ supported_jina_colbert_models: list[DenseModelDescription] = [
+     DenseModelDescription(
+         model="jinaai/jina-colbert-v2",
+         dim=128,
+         description="New model that expands the capabilities of colbert-v1 with multilingual support and a context length of 8192, 2024 year",
+         license="cc-by-nc-4.0",
+         size_in_GB=2.24,
+         sources=ModelSource(hf="jinaai/jina-colbert-v2"),
+         model_file="onnx/model.onnx",
+         additional_files=["onnx/model.onnx_data"],
+     )
+ ]
+
+
+ class JinaColbert(Colbert):
+     QUERY_MARKER_TOKEN_ID = 250002
+     DOCUMENT_MARKER_TOKEN_ID = 250003
+     MIN_QUERY_LENGTH = 31  # effectively 32: one additional special token is added at the beginning
+     MASK_TOKEN = "<mask>"
+
+     @classmethod
+     def _get_worker_class(cls) -> Type[ColbertEmbeddingWorker]:
+         return JinaColbertEmbeddingWorker
+
+     @classmethod
+     def _list_supported_models(cls) -> list[DenseModelDescription]:
+         """Lists the supported models.
+
+         Returns:
+             list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+         """
+         return supported_jina_colbert_models
+
+     def _preprocess_onnx_input(
+         self, onnx_input: dict[str, NumpyArray], is_doc: bool = True, **kwargs: Any
+     ) -> dict[str, NumpyArray]:
+         onnx_input = super()._preprocess_onnx_input(onnx_input, is_doc)
+
+         # the attention mask for jina-colbert-v2 is always 1 in queries
+         if not is_doc:
+             onnx_input["attention_mask"][:] = 1
+         return onnx_input
+
+
+ class JinaColbertEmbeddingWorker(ColbertEmbeddingWorker):
+     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> JinaColbert:
+         return JinaColbert(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=1,
+             **kwargs,
+         )
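To make the marker-token handling concrete, here is a small numpy-only sketch (not from the package; token ids are illustrative) of what _preprocess_onnx_input does to a query batch, including the jina-colbert-v2 override that attends to every position:

import numpy as np

QUERY_MARKER_TOKEN_ID = 250002  # JinaColbert.QUERY_MARKER_TOKEN_ID

input_ids = np.array([[0, 581, 903, 2, 250001]], dtype=np.int64)  # <s> ... </s> <mask>
attention_mask = np.array([[1, 1, 1, 1, 0]], dtype=np.int64)

# Insert the query marker (and an attended position) right after the first special token.
input_ids = np.insert(input_ids, 1, QUERY_MARKER_TOKEN_ID, axis=1)
attention_mask = np.insert(attention_mask, 1, 1, axis=1)

# JinaColbert override: queries attend to everything, including <mask> padding.
attention_mask[:] = 1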
fastembed/late_interaction/late_interaction_embedding_base.py
@@ -0,0 +1,80 @@
+ from typing import Iterable, Any
+
+ from fastembed.common.model_description import DenseModelDescription
+ from fastembed.common.types import NumpyArray
+ from fastembed.common.model_management import ModelManagement
+
+
+ class LateInteractionTextEmbeddingBase(ModelManagement[DenseModelDescription]):
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         **kwargs: Any,
+     ):
+         self.model_name = model_name
+         self.cache_dir = cache_dir
+         self.threads = threads
+         self._local_files_only = kwargs.pop("local_files_only", False)
+         self._embedding_size: int | None = None
+
+     def embed(
+         self,
+         documents: str | Iterable[str],
+         batch_size: int = 256,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[NumpyArray]:
+         raise NotImplementedError()
+
+     def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
+         """
+         Embeds a list of text passages into a list of embeddings.
+
+         Args:
+             texts (Iterable[str]): The list of texts to embed.
+             **kwargs: Additional keyword arguments to pass to the embed method.
+
+         Yields:
+             Iterable[NumpyArray]: The embeddings.
+         """
+
+         # This is model-specific, so that different models can have specialized implementations
+         yield from self.embed(texts, **kwargs)
+
+     def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
+         """
+         Embeds queries.
+
+         Args:
+             query (Union[str, Iterable[str]]): The query to embed, or an iterable, e.g. a list of queries.
+
+         Returns:
+             Iterable[NumpyArray]: The embeddings.
+         """
+
+         # This is model-specific, so that different models can have specialized implementations
+         if isinstance(query, str):
+             yield from self.embed([query], **kwargs)
+         else:
+             yield from self.embed(query, **kwargs)
+
+     @classmethod
+     def get_embedding_size(cls, model_name: str) -> int:
+         """Returns the embedding size of the chosen model."""
+         raise NotImplementedError("Subclasses must implement this method")
+
+     @property
+     def embedding_size(self) -> int:
+         """Returns the embedding size for the current model."""
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def token_count(
+         self,
+         texts: str | Iterable[str],
+         batch_size: int = 1024,
+         **kwargs: Any,
+     ) -> int:
+         """Returns the number of tokens in the texts."""
+         raise NotImplementedError("Subclasses must implement this method")
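The base class above fixes the public surface: subclasses only have to implement embed() (plus the size and token-count hooks), while passage_embed() and query_embed() fall back to it. A minimal, hypothetical subclass sketch showing that contract (EchoLateInteraction is invented for illustration and is not part of the package):

from typing import Any, Iterable

import numpy as np

from fastembed.common.types import NumpyArray
from fastembed.late_interaction.late_interaction_embedding_base import (
    LateInteractionTextEmbeddingBase,
)


class EchoLateInteraction(LateInteractionTextEmbeddingBase):  # hypothetical
    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        texts = [documents] if isinstance(documents, str) else documents
        for text in texts:
            # one dummy 4-dim vector per whitespace token, mimicking a multi-vector output
            yield np.zeros((len(text.split()), 4), dtype=np.float32)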
fastembed/late_interaction/late_interaction_text_embedding.py
@@ -0,0 +1,180 @@
+ from typing import Any, Iterable, Sequence, Type
+ from dataclasses import asdict
+
+ from fastembed.common.model_description import DenseModelDescription
+ from fastembed.common.types import NumpyArray, Device
+ from fastembed.common import OnnxProvider
+ from fastembed.late_interaction.colbert import Colbert
+ from fastembed.late_interaction.jina_colbert import JinaColbert
+ from fastembed.late_interaction.late_interaction_embedding_base import (
+     LateInteractionTextEmbeddingBase,
+ )
+
+
+ class LateInteractionTextEmbedding(LateInteractionTextEmbeddingBase):
+     EMBEDDINGS_REGISTRY: list[Type[LateInteractionTextEmbeddingBase]] = [Colbert, JinaColbert]
+
+     @classmethod
+     def list_supported_models(cls) -> list[dict[str, Any]]:
+         """
+         Lists the supported models.
+
+         Returns:
+             list[dict[str, Any]]: A list of dictionaries containing the model information.
+
+         Example:
+             ```
+             [
+                 {
+                     "model": "colbert-ir/colbertv2.0",
+                     "dim": 128,
+                     "description": "Late interaction model",
+                     "license": "mit",
+                     "size_in_GB": 0.44,
+                     "sources": {
+                         "hf": "colbert-ir/colbertv2.0",
+                     },
+                     "model_file": "model.onnx",
+                 },
+             ]
+             ```
+         """
+         return [asdict(model) for model in cls._list_supported_models()]
+
+     @classmethod
+     def _list_supported_models(cls) -> list[DenseModelDescription]:
+         result: list[DenseModelDescription] = []
+         for embedding in cls.EMBEDDINGS_REGISTRY:
+             result.extend(embedding._list_supported_models())
+         return result
+
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         lazy_load: bool = False,
+         **kwargs: Any,
+     ):
+         super().__init__(model_name, cache_dir, threads, **kwargs)
+         for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
+             supported_models = EMBEDDING_MODEL_TYPE._list_supported_models()
+             if any(model_name.lower() == model.model.lower() for model in supported_models):
+                 self.model = EMBEDDING_MODEL_TYPE(
+                     model_name,
+                     cache_dir,
+                     threads=threads,
+                     providers=providers,
+                     cuda=cuda,
+                     device_ids=device_ids,
+                     lazy_load=lazy_load,
+                     **kwargs,
+                 )
+                 return
+
+         raise ValueError(
+             f"Model {model_name} is not supported in LateInteractionTextEmbedding. "
+             "Please check the supported models using `LateInteractionTextEmbedding.list_supported_models()`"
+         )
+
+     @property
+     def embedding_size(self) -> int:
+         """Get the embedding size of the current model."""
+         if self._embedding_size is None:
+             self._embedding_size = self.get_embedding_size(self.model_name)
+         return self._embedding_size
+
+     @classmethod
+     def get_embedding_size(cls, model_name: str) -> int:
+         """Get the embedding size of the passed model.
+
+         Args:
+             model_name (str): The name of the model to get the embedding size for.
+
+         Returns:
+             int: The size of the embedding.
+
+         Raises:
+             ValueError: If the model name is not found in the supported models.
+         """
+         descriptions = cls._list_supported_models()
+         embedding_size: int | None = None
+         for description in descriptions:
+             if description.model.lower() == model_name.lower():
+                 embedding_size = description.dim
+                 break
+         if embedding_size is None:
+             model_names = [description.model for description in descriptions]
+             raise ValueError(
+                 f"Embedding size for model {model_name} was not found. "
+                 f"Available model names: {model_names}"
+             )
+         return embedding_size
+
+     def embed(
+         self,
+         documents: str | Iterable[str],
+         batch_size: int = 256,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[NumpyArray]:
+         """
+         Encode a list of documents into a list of embeddings.
+         Each document yields a matrix of token-level embeddings, one row per attended token.
+
+         Args:
+             documents: Iterator of documents or a single document to embed
+             batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+             parallel:
+                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                 If 0, use all available cores.
+                 If None, don't use data-parallel processing; use default onnxruntime threading instead.
+
+         Returns:
+             List of embeddings, one per document
+         """
+         yield from self.model.embed(documents, batch_size, parallel, **kwargs)
+
+     def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
+         """
+         Embeds queries.
+
+         Args:
+             query (Union[str, Iterable[str]]): The query to embed, or an iterable, e.g. a list of queries.
+
+         Returns:
+             Iterable[NumpyArray]: The embeddings.
+         """
+
+         # This is model-specific, so that different models can have specialized implementations
+         yield from self.model.query_embed(query, **kwargs)
+
+     def token_count(
+         self,
+         texts: str | Iterable[str],
+         batch_size: int = 1024,
+         is_doc: bool = True,
+         include_extension: bool = False,
+         **kwargs: Any,
+     ) -> int:
+         """Returns the number of tokens in the texts.
+
+         Args:
+             texts (str | Iterable[str]): The list of texts to count tokens for.
+             batch_size (int): Batch size for encoding.
+             is_doc (bool): Whether to tokenize the texts as documents; when False, the query tokenizer is used.
+             include_extension (bool): Turn on to also count the DOCUMENT / QUERY marker tokens, and the [MASK] padding tokens in query mode.
+
+         Returns:
+             int: Sum of the number of tokens in the texts.
+         """
+         return self.model.token_count(
+             texts,
+             batch_size=batch_size,
+             is_doc=is_doc,
+             include_extension=include_extension,
+             **kwargs,
+         )
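For completeness, a sketch of how the user-facing class above is driven, plus the standard late-interaction MaxSim score these token matrices are built for (the scoring helper is illustrative and not part of this package):

import numpy as np

from fastembed.late_interaction.late_interaction_text_embedding import (
    LateInteractionTextEmbedding,
)

model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")
docs = list(model.embed(["ColBERT scores queries against documents token by token."]))
queries = list(model.query_embed("how does colbert score documents?"))


def maxsim(query_matrix: np.ndarray, doc_matrix: np.ndarray) -> float:
    # For each query token, keep its best-matching document token, then sum over query tokens.
    similarities = query_matrix @ doc_matrix.T  # (num_query_tokens, num_doc_tokens)
    return float(similarities.max(axis=1).sum())


print(maxsim(queries[0], docs[0]))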