fastembed_bio-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/postprocess/muvera.py ADDED
@@ -0,0 +1,362 @@
+ import numpy as np
+
+ from fastembed.common.types import NumpyArray
+ from fastembed.late_interaction.late_interaction_embedding_base import (
+     LateInteractionTextEmbeddingBase,
+ )
+ from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
+     LateInteractionMultimodalEmbeddingBase,
+ )
+
+
+ MultiVectorModel = LateInteractionTextEmbeddingBase | LateInteractionMultimodalEmbeddingBase
+ MAX_HAMMING_DISTANCE = 65  # 64 bits + 1
+ POPCOUNT_LUT = np.array([bin(x).count("1") for x in range(256)], dtype=np.uint8)
+
+
+ def hamming_distance_matrix(ids: np.ndarray) -> np.ndarray:
+     """Compute the full pairwise Hamming distance matrix.
+
+     Args:
+         ids: shape (n,) - array of integer ids; distances are computed between
+             the binary representations of the id values
+
+     Returns:
+         np.ndarray (n, n) - Hamming distance matrix
+     """
+     n = len(ids)
+     xor_vals = np.bitwise_xor(ids[:, None], ids[None, :])  # (n, n), 64-bit ints
+     bytes_view = xor_vals.view(np.uint8).reshape(n, n, 8)  # (n, n, 8)
+     return POPCOUNT_LUT[bytes_view].sum(axis=2)
+
+
+ class SimHashProjection:
+     """
+     SimHash projection component for MUVERA clustering.
+
+     This class implements locality-sensitive hashing using random hyperplanes
+     to partition the vector space into 2^k_sim clusters. Each vector is assigned
+     to a cluster based on which side of the k_sim random hyperplanes it falls on.
+
+     Attributes:
+         k_sim (int): Number of SimHash functions (hyperplanes)
+         dim (int): Dimensionality of input vectors
+         simhash_vectors (np.ndarray): Random hyperplane normal vectors of shape (dim, k_sim)
+     """
+
+     def __init__(self, k_sim: int, dim: int, random_generator: np.random.Generator):
+         """
+         Initialize SimHash projection with random hyperplanes.
+
+         Args:
+             k_sim (int): Number of SimHash functions, determines 2^k_sim clusters
+             dim (int): Dimensionality of input vectors
+             random_generator (np.random.Generator): Random number generator for reproducibility
+         """
+         self.k_sim = k_sim
+         self.dim = dim
+         # Generate k_sim random hyperplanes (normal vectors) from the standard normal distribution
+         self.simhash_vectors = random_generator.normal(size=(dim, k_sim))
+
+     def get_cluster_ids(self, vectors: np.ndarray) -> np.ndarray:
+         """
+         Compute the cluster IDs for the given vectors using SimHash.
+
+         Each cluster ID is determined by computing the dot product of a vector
+         with each hyperplane normal vector, taking the sign, and interpreting
+         the resulting binary string as an integer.
+
+         Args:
+             vectors (np.ndarray): Input vectors of shape (n, dim)
+
+         Returns:
+             np.ndarray: Cluster IDs in range [0, 2^k_sim - 1]
+         """
+         dot_product = (
+             vectors @ self.simhash_vectors
+         )  # (token_num, dim) x (dim, k_sim) -> (token_num, k_sim)
+         cluster_ids = (dot_product > 0) @ (1 << np.arange(self.k_sim))
+         return cluster_ids
+
+
+ class Muvera:
+     """
+     MUVERA (Multi-Vector Retrieval Algorithm) implementation.
+
+     This class creates Fixed Dimensional Encodings (FDEs) from variable-length
+     sequences of vectors by using SimHash clustering and random projections.
+     The process involves:
+     1. Clustering vectors using multiple SimHash projections
+     2. Computing cluster centers (with different strategies for docs vs queries)
+     3. Applying random projections for dimensionality reduction
+     4. Concatenating results from all projections
+
+     Attributes:
+         k_sim (int): Number of SimHash functions per projection
+         dim (int): Input vector dimensionality
+         dim_proj (int): Output dimensionality after random projection
+         r_reps (int): Number of random projection repetitions
+         random_seed (int): Random seed for consistent random matrix generation
+         simhash_projections (list[SimHashProjection]): SimHash instances for clustering
+         dim_reduction_projections (np.ndarray): Random projection matrices of shape (r_reps, dim, dim_proj)
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         k_sim: int = 5,
+         dim_proj: int = 16,
+         r_reps: int = 20,
+         random_seed: int = 42,
+     ):
+         """
+         Initialize MUVERA with the specified parameters.
+
+         Args:
+             dim (int): Dimensionality of individual input vectors
+             k_sim (int, optional): Number of SimHash functions (creates 2^k_sim clusters).
+                 Defaults to 5.
+             dim_proj (int, optional): Dimensionality after random projection (must be <= dim).
+                 Defaults to 16.
+             r_reps (int, optional): Number of random projection repetitions for robustness.
+                 Defaults to 20.
+             random_seed (int, optional): Seed for the random number generator to ensure
+                 reproducible results. Defaults to 42.
+
+         Raises:
+             ValueError: If dim_proj > dim (cannot project to a higher dimensionality)
+         """
+         if dim_proj > dim:
+             raise ValueError(
+                 f"Cannot project to a higher dimensionality (dim_proj={dim_proj} > dim={dim})"
+             )
+
+         self.k_sim = k_sim
+         self.dim = dim
+         self.dim_proj = dim_proj
+         self.r_reps = r_reps
+         # Create r_reps independent SimHash projections for robustness
+         generator = np.random.default_rng(random_seed)
+         self.simhash_projections = [
+             SimHashProjection(k_sim=self.k_sim, dim=self.dim, random_generator=generator)
+             for _ in range(r_reps)
+         ]
+         # Random projection matrices with entries from {-1, +1} for each repetition
+         self.dim_reduction_projections = generator.choice([-1, 1], size=(r_reps, dim, dim_proj))
+
+     @classmethod
+     def from_multivector_model(
+         cls,
+         model: MultiVectorModel,
+         k_sim: int = 5,
+         dim_proj: int = 16,
+         r_reps: int = 20,  # noqa[naming]
+         random_seed: int = 42,
+     ) -> "Muvera":
+         """
+         Create a Muvera instance from a multi-vector embedding model.
+
+         This class method provides a convenient way to initialize a MUVERA instance
+         that is compatible with a given multi-vector model by automatically extracting
+         the embedding dimensionality from the model.
+
+         Args:
+             model (MultiVectorModel): A late interaction text or multimodal embedding model
+                 that provides multi-vector embeddings. Must have an
+                 `embedding_size` attribute specifying the dimensionality
+                 of individual vectors.
+             k_sim (int, optional): Number of SimHash functions (creates 2^k_sim clusters).
+                 Defaults to 5.
+             dim_proj (int, optional): Dimensionality after random projection (must be <= the model's
+                 embedding_size). Defaults to 16.
+             r_reps (int, optional): Number of random projection repetitions for robustness.
+                 Defaults to 20.
+             random_seed (int, optional): Seed for the random number generator to ensure
+                 reproducible results. Defaults to 42.
+
+         Returns:
+             Muvera: A configured MUVERA instance ready to process embeddings from the given model.
+
+         Raises:
+             ValueError: If dim_proj > model.embedding_size (cannot project to a higher dimensionality)
+
+         Example:
+             >>> from fastembed import LateInteractionTextEmbedding
+             >>> model = LateInteractionTextEmbedding(model_name="colbert-ir/colbertv2.0")
+             >>> muvera = Muvera.from_multivector_model(
+             ...     model=model,
+             ...     k_sim=6,
+             ...     dim_proj=32
+             ... )
+             >>> # Now use the postprocessor with embeddings from the model
+             >>> embeddings = np.array(list(model.embed(["sample text"])))
+             >>> fde = muvera.process_document(embeddings[0])
+         """
+         return cls(
+             dim=model.embedding_size,
+             k_sim=k_sim,
+             dim_proj=dim_proj,
+             r_reps=r_reps,
+             random_seed=random_seed,
+         )
+
+     def _get_output_dimension(self) -> int:
+         """
+         Get the output dimension of the MUVERA encoding.
+
+         Returns:
+             int: Output dimension (r_reps * num_partitions * dim_proj), where
+                 num_partitions = 2^k_sim
+         """
+         num_partitions = 2**self.k_sim
+         return self.r_reps * num_partitions * self.dim_proj
+
+     @property
+     def embedding_size(self) -> int:
+         return self._get_output_dimension()
+
+     def process_document(self, vectors: NumpyArray) -> NumpyArray:
+         """
+         Encode a document's vectors into a Fixed Dimensional Encoding (FDE).
+
+         Uses document-specific settings: normalizes cluster centers by vector count
+         and fills empty clusters using Hamming distance-based selection.
+
+         Args:
+             vectors (NumpyArray): Document vectors of shape (n_tokens, dim)
+
+         Returns:
+             NumpyArray: Fixed dimensional encoding of shape (r_reps * b * dim_proj,),
+                 where b = 2^k_sim
+         """
+         return self.process(vectors, fill_empty_clusters=True, normalize_by_count=True)
+
+     def process_query(self, vectors: NumpyArray) -> NumpyArray:
+         """
+         Encode a query's vectors into a Fixed Dimensional Encoding (FDE).
+
+         Uses query-specific settings: no normalization by count and no empty
+         cluster filling, to preserve query vector magnitudes.
+
+         Args:
+             vectors (NumpyArray): Query vectors of shape (n_tokens, dim)
+
+         Returns:
+             NumpyArray: Fixed dimensional encoding of shape (r_reps * b * dim_proj,),
+                 where b = 2^k_sim
+         """
+         return self.process(vectors, fill_empty_clusters=False, normalize_by_count=False)
+
+     def process(
+         self,
+         vectors: NumpyArray,
+         fill_empty_clusters: bool = True,
+         normalize_by_count: bool = True,
+     ) -> NumpyArray:
+         """
+         Core encoding method that transforms variable-length vector sequences into FDEs.
+
+         The encoding process:
+         1. For each of r_reps random projections:
+             a. Assign vectors to clusters using SimHash
+             b. Compute cluster centers (sum of vectors in each cluster)
+             c. Optionally normalize by cluster size
+             d. Fill empty clusters using Hamming distance if requested
+             e. Apply random projection for dimensionality reduction
+             f. Flatten cluster centers into a vector
+         2. Concatenate all projection results
+
+         Args:
+             vectors (np.ndarray): Input vectors of shape (n_vectors, dim)
+             fill_empty_clusters (bool): Whether to fill empty clusters using nearest
+                 vectors based on Hamming distance of cluster IDs
+             normalize_by_count (bool): Whether to normalize cluster centers by the
+                 number of vectors assigned to each cluster
+
+         Returns:
+             np.ndarray: Fixed dimensional encoding of shape (r_reps * b * dim_proj,),
+                 where b = 2^k_sim is the number of clusters
+
+         Raises:
+             AssertionError: If input vectors don't have the expected dimensionality
+         """
+         assert (
+             vectors.shape[1] == self.dim
+         ), f"Expected vectors of shape (n, {self.dim}), got {vectors.shape}"
+
+         # Store results from each random projection
+         output_vectors = []
+
+         # Number of space partitions in SimHash
+         num_partitions = 2**self.k_sim
+         cluster_center_ids = np.arange(num_partitions)
+         precomputed_hamming_matrix = (
+             hamming_distance_matrix(cluster_center_ids) if fill_empty_clusters else None
+         )
+
+         for projection_index, simhash in enumerate(self.simhash_projections):
+             # Initialize cluster centers and track which vectors fall into each cluster
+             cluster_centers = np.zeros((num_partitions, self.dim))
+             cluster_center_id_to_vectors: dict[int, list[int]] = {
+                 cluster_center_id: [] for cluster_center_id in cluster_center_ids
+             }
+             cluster_vector_counts = None
+             empty_mask = None
+
+             # Assign each vector to its cluster and accumulate cluster centers
+             vector_cluster_ids = simhash.get_cluster_ids(vectors)
+             for cluster_id, (vec_idx, vec) in zip(vector_cluster_ids, enumerate(vectors)):
+                 cluster_centers[cluster_id] += vec
+                 cluster_center_id_to_vectors[cluster_id].append(vec_idx)
+
+             if normalize_by_count or fill_empty_clusters:
+                 cluster_vector_counts = np.bincount(vector_cluster_ids, minlength=num_partitions)
+                 empty_mask = cluster_vector_counts == 0
+
+             if normalize_by_count:
+                 assert empty_mask is not None
+                 assert cluster_vector_counts is not None
+                 non_empty_mask = ~empty_mask
+                 cluster_centers[non_empty_mask] /= cluster_vector_counts[non_empty_mask][:, None]
+
+             # Fill empty clusters using vectors with minimum Hamming distance
+             if fill_empty_clusters:
+                 assert empty_mask is not None
+                 assert precomputed_hamming_matrix is not None
+                 masked_hamming = np.where(
+                     empty_mask[None, :], MAX_HAMMING_DISTANCE, precomputed_hamming_matrix
+                 )
+                 nearest_non_empty = np.argmin(masked_hamming, axis=1)
+                 fill_vectors = np.array(
+                     [
+                         vectors[cluster_center_id_to_vectors[cluster_id][0]]
+                         for cluster_id in nearest_non_empty[empty_mask]
+                     ]
+                 ).reshape(-1, self.dim)
+                 cluster_centers[empty_mask] = fill_vectors
+
+             # Apply random projection for dimensionality reduction if needed
+             if self.dim_proj < self.dim:
+                 dim_reduction_projection = self.dim_reduction_projections[
+                     projection_index
+                 ]  # Get the projection matrix for this repetition
+                 projected_centers = (1 / np.sqrt(self.dim_proj)) * (
+                     cluster_centers @ dim_reduction_projection
+                 )
+
+                 # Flatten cluster centers into a single vector and add to the output
+                 output_vectors.append(projected_centers.flatten())
+                 continue
+
+             # If no projection is needed (dim_proj == dim), use the original cluster centers
+             output_vectors.append(cluster_centers.flatten())
+
+         # Concatenate the results from all r_reps projections into the final FDE
+         return np.concatenate(output_vectors)
+
+
+ if __name__ == "__main__":
+     v_arrs = np.random.randn(10, 100, 128)
+     muvera = Muvera(128, 4, 8, 20, 42)
+
+     for v_arr in v_arrs:
+         muvera.process(v_arr)  # type: ignore
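
A minimal usage sketch for the postprocessor above (not part of the diff), assuming only the Muvera methods shown in this file and the import path from the file list: documents and queries are encoded with different settings, both FDEs have the same fixed size, and scoring reduces to a single dot product, which approximates the multi-vector MaxSim score.

import numpy as np

from fastembed.postprocess.muvera import Muvera

muvera = Muvera(dim=128, k_sim=4, dim_proj=8, r_reps=20)

# Variable-length multi-vector embeddings: a 40-token document, an 8-token query
doc_vectors = np.random.randn(40, 128)
query_vectors = np.random.randn(8, 128)

# Documents normalize cluster centers and fill empty clusters;
# queries keep raw sums so vector magnitudes are preserved.
doc_fde = muvera.process_document(doc_vectors)
query_fde = muvera.process_query(query_vectors)

# Both FDEs share the same fixed size: r_reps * 2**k_sim * dim_proj = 20 * 16 * 8
assert doc_fde.shape == query_fde.shape == (muvera.embedding_size,)

# The dot product of the two FDEs approximates the Chamfer (MaxSim) similarity
score = float(query_fde @ doc_fde)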
fastembed/py.typed ADDED
@@ -0,0 +1 @@
+ partial
fastembed/rerank/cross_encoder/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from fastembed.rerank.cross_encoder.text_cross_encoder import TextCrossEncoder
+
+ __all__ = ["TextCrossEncoder"]
fastembed/rerank/cross_encoder/custom_text_cross_encoder.py ADDED
@@ -0,0 +1,47 @@
+ from typing import Sequence, Any
+
+ from fastembed.common import OnnxProvider
+ from fastembed.common.model_description import BaseModelDescription
+ from fastembed.common.types import Device
+ from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
+
+
+ class CustomTextCrossEncoder(OnnxTextCrossEncoder):
+     SUPPORTED_MODELS: list[BaseModelDescription] = []
+
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         lazy_load: bool = False,
+         device_id: int | None = None,
+         specific_model_path: str | None = None,
+         **kwargs: Any,
+     ):
+         super().__init__(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=threads,
+             providers=providers,
+             cuda=cuda,
+             device_ids=device_ids,
+             lazy_load=lazy_load,
+             device_id=device_id,
+             specific_model_path=specific_model_path,
+             **kwargs,
+         )
+
+     @classmethod
+     def _list_supported_models(cls) -> list[BaseModelDescription]:
+         return cls.SUPPORTED_MODELS
+
+     @classmethod
+     def add_model(
+         cls,
+         model_description: BaseModelDescription,
+     ) -> None:
+         cls.SUPPORTED_MODELS.append(model_description)
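
The class-level registry above is a simple extension point: descriptions appended via add_model become visible through _list_supported_models. A short sketch (not part of the diff; "my-org/my-reranker" is a hypothetical repository, and the BaseModelDescription fields mirror the entries in onnx_text_cross_encoder.py below):

from fastembed.common.model_description import BaseModelDescription, ModelSource
from fastembed.rerank.cross_encoder.custom_text_cross_encoder import CustomTextCrossEncoder

# Register a custom (hypothetical) model description
CustomTextCrossEncoder.add_model(
    BaseModelDescription(
        model="my-org/my-reranker",
        description="Privately fine-tuned cross-encoder.",
        license="apache-2.0",
        size_in_GB=0.1,
        sources=ModelSource(hf="my-org/my-reranker"),
        model_file="onnx/model.onnx",
    )
)

# Once registered, the name resolves like any built-in description
encoder = CustomTextCrossEncoder(model_name="my-org/my-reranker")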
fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py ADDED
@@ -0,0 +1,239 @@
+ from typing import Any, Iterable, Sequence, Type
+
+ from loguru import logger
+
+ from fastembed.common import OnnxProvider
+ from fastembed.common.onnx_model import OnnxOutputContext
+ from fastembed.common.types import Device
+ from fastembed.common.utils import define_cache_dir
+ from fastembed.rerank.cross_encoder.onnx_text_model import (
+     OnnxCrossEncoderModel,
+     TextRerankerWorker,
+ )
+ from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
+ from fastembed.common.model_description import BaseModelDescription, ModelSource
+
+ supported_onnx_models: list[BaseModelDescription] = [
+     BaseModelDescription(
+         model="Xenova/ms-marco-MiniLM-L-6-v2",
+         description="MiniLM-L-6-v2 model optimized for re-ranking tasks.",
+         license="apache-2.0",
+         size_in_GB=0.08,
+         sources=ModelSource(hf="Xenova/ms-marco-MiniLM-L-6-v2"),
+         model_file="onnx/model.onnx",
+     ),
+     BaseModelDescription(
+         model="Xenova/ms-marco-MiniLM-L-12-v2",
+         description="MiniLM-L-12-v2 model optimized for re-ranking tasks.",
+         license="apache-2.0",
+         size_in_GB=0.12,
+         sources=ModelSource(hf="Xenova/ms-marco-MiniLM-L-12-v2"),
+         model_file="onnx/model.onnx",
+     ),
+     BaseModelDescription(
+         model="BAAI/bge-reranker-base",
+         description="BGE reranker base model for cross-encoder re-ranking.",
+         license="mit",
+         size_in_GB=1.04,
+         sources=ModelSource(hf="BAAI/bge-reranker-base"),
+         model_file="onnx/model.onnx",
+     ),
+     BaseModelDescription(
+         model="jinaai/jina-reranker-v1-tiny-en",
+         description="Designed for blazing-fast re-ranking with 8K context length and fewer parameters than jina-reranker-v1-turbo-en.",
+         license="apache-2.0",
+         size_in_GB=0.13,
+         sources=ModelSource(hf="jinaai/jina-reranker-v1-tiny-en"),
+         model_file="onnx/model.onnx",
+     ),
+     BaseModelDescription(
+         model="jinaai/jina-reranker-v1-turbo-en",
+         description="Designed for blazing-fast re-ranking with 8K context length.",
+         license="apache-2.0",
+         size_in_GB=0.15,
+         sources=ModelSource(hf="jinaai/jina-reranker-v1-turbo-en"),
+         model_file="onnx/model.onnx",
+     ),
+     BaseModelDescription(
+         model="jinaai/jina-reranker-v2-base-multilingual",
+         description="A multi-lingual reranker model for cross-encoder re-ranking with 1K context length and sliding window",
+         license="cc-by-nc-4.0",
+         size_in_GB=1.11,
+         sources=ModelSource(hf="jinaai/jina-reranker-v2-base-multilingual"),
+         model_file="onnx/model.onnx",
+     ),
+ ]
+
+
+ class OnnxTextCrossEncoder(TextCrossEncoderBase, OnnxCrossEncoderModel):
+     @classmethod
+     def _list_supported_models(cls) -> list[BaseModelDescription]:
+         """Lists the supported models.
+
+         Returns:
+             list[BaseModelDescription]: A list of BaseModelDescription objects containing the model information.
+         """
+         return supported_onnx_models
+
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         lazy_load: bool = False,
+         device_id: int | None = None,
+         specific_model_path: str | None = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             model_name (str): The name of the model to use.
+             cache_dir (str, optional): The path to the cache directory.
+                 Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+                 Defaults to `fastembed_cache` in the system's temp directory.
+             threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
+             providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                 Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+             cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
+                 Defaults to Device.AUTO.
+             device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
+                 workers. Should be used with `cuda` set to `True`, `Device.AUTO`, or `Device.CUDA`; mutually exclusive
+                 with `providers`. Defaults to None.
+             lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                 Should be set to True when using multiple GPUs and parallel encoding. Defaults to False.
+             device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
+             specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be loaded from a custom location.
+
+         Raises:
+             ValueError: If the model_name is not in the format <org>/<model>, e.g. Xenova/ms-marco-MiniLM-L-6-v2.
+         """
+         super().__init__(model_name, cache_dir, threads, **kwargs)
+         self.providers = providers
+         self.lazy_load = lazy_load
+         self._extra_session_options = self._select_exposed_session_options(kwargs)
+
+         # List of device ids that can be used for data parallel processing in workers
+         self.device_ids = device_ids
+         self.cuda = cuda
+
+         if self.device_ids is not None and len(self.device_ids) > 1:
+             logger.warning(
+                 "Parallel execution is currently not supported for cross encoders, "
+                 f"only the first device will be used for inference: {self.device_ids[0]}."
+             )
+
+         # This device_id will be used if we need to load the model in the current process
+         self.device_id: int | None = None
+         if device_id is not None:
+             self.device_id = device_id
+         elif self.device_ids is not None:
+             self.device_id = self.device_ids[0]
+
+         self.model_description = self._get_model_description(model_name)
+         self.cache_dir = str(define_cache_dir(cache_dir))
+         self._specific_model_path = specific_model_path
+         self._model_dir = self.download_model(
+             self.model_description,
+             self.cache_dir,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+         )
+
+         if not self.lazy_load:
+             self.load_onnx_model()
+
+     def load_onnx_model(self) -> None:
+         self._load_onnx_model(
+             model_dir=self._model_dir,
+             model_file=self.model_description.model_file,
+             threads=self.threads,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_id=self.device_id,
+             extra_session_options=self._extra_session_options,
+         )
+
+     def rerank(
+         self,
+         query: str,
+         documents: Iterable[str],
+         batch_size: int = 64,
+         **kwargs: Any,
+     ) -> Iterable[float]:
+         """Reranks documents based on their relevance to a given query.
+
+         Args:
+             query (str): The query string to which document relevance is calculated.
+             documents (Iterable[str]): Iterable of documents to be reranked.
+             batch_size (int, optional): The number of documents processed in each batch. Higher batch sizes improve speed
+                 but require more memory. Default is 64.
+
+         Returns:
+             Iterable[float]: An iterable of relevance scores for each document.
+         """
+
+         yield from self._rerank_documents(
+             query=query, documents=documents, batch_size=batch_size, **kwargs
+         )
+
+     def rerank_pairs(
+         self,
+         pairs: Iterable[tuple[str, str]],
+         batch_size: int = 64,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[float]:
+         yield from self._rerank_pairs(
+             model_name=self.model_name,
+             cache_dir=str(self.cache_dir),
+             pairs=pairs,
+             batch_size=batch_size,
+             parallel=parallel,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_ids=self.device_ids,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+             extra_session_options=self._extra_session_options,
+             **kwargs,
+         )
+
+     @classmethod
+     def _get_worker_class(cls) -> Type[TextRerankerWorker]:
+         return TextCrossEncoderWorker
+
+     def _post_process_onnx_output(
+         self, output: OnnxOutputContext, **kwargs: Any
+     ) -> Iterable[float]:
+         return (float(elem) for elem in output.model_output)
+
+     def token_count(
+         self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **kwargs: Any
+     ) -> int:
+         """Returns the number of tokens in the pairs.
+
+         Args:
+             pairs: Iterable of tuples, where each tuple contains a query and a document to be tokenized
+             batch_size: Batch size for tokenizing
+
+         Returns:
+             token count: overall number of tokens in the pairs
+         """
+         return self._token_count(pairs, batch_size=batch_size, **kwargs)
+
+
+ class TextCrossEncoderWorker(TextRerankerWorker):
+     def init_embedding(
+         self,
+         model_name: str,
+         cache_dir: str,
+         **kwargs: Any,
+     ) -> OnnxTextCrossEncoder:
+         return OnnxTextCrossEncoder(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=1,
+             **kwargs,
+         )
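
A brief usage sketch for the reranking API above, via the TextCrossEncoder facade exported from the package __init__ (assuming it forwards rerank and rerank_pairs to the ONNX implementation shown here):

from fastembed.rerank.cross_encoder import TextCrossEncoder

# One of the descriptions from supported_onnx_models above
encoder = TextCrossEncoder(model_name="Xenova/ms-marco-MiniLM-L-6-v2")

query = "What is the capital of France?"
documents = [
    "Paris is the capital and most populous city of France.",
    "The Eiffel Tower is a wrought-iron lattice tower.",
]

# rerank() yields one relevance score per document, in input order
scores = list(encoder.rerank(query, documents, batch_size=64))

# rerank_pairs() scores arbitrary (query, document) tuples
pair_scores = list(encoder.rerank_pairs([(query, doc) for doc in documents]))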