fastembed-bio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from multiprocessing import get_all_start_methods
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Iterable, Sequence, Type
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from tokenizers import Encoding
|
|
8
|
+
|
|
9
|
+
from fastembed.common.onnx_model import (
|
|
10
|
+
EmbeddingWorker,
|
|
11
|
+
OnnxModel,
|
|
12
|
+
OnnxOutputContext,
|
|
13
|
+
OnnxProvider,
|
|
14
|
+
)
|
|
15
|
+
from fastembed.common.types import NumpyArray, Device
|
|
16
|
+
from fastembed.common.preprocessor_utils import load_tokenizer
|
|
17
|
+
from fastembed.common.utils import iter_batch
|
|
18
|
+
from fastembed.parallel_processor import ParallelWorkerPool
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class OnnxCrossEncoderModel(OnnxModel[float]):
    """Base ONNX cross-encoder that scores (query, document) pairs.

    Subclasses supply ``_get_worker_class`` (the worker used for data-parallel
    reranking) and ``_post_process_onnx_output`` (which turns the raw
    ``OnnxOutputContext`` into float scores).
    """

    # Output names passed to ``InferenceSession.run``; ``None`` requests every
    # model output. Only the first returned output is used for scoring.
    ONNX_OUTPUT_NAMES: list[str] | None = None

    @classmethod
    def _get_worker_class(cls) -> Type["TextRerankerWorker"]:
        """Return the worker class used by ``ParallelWorkerPool``.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def _load_onnx_model(
        self,
        model_dir: Path,
        model_file: str,
        threads: int | None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_id: int | None = None,
        extra_session_options: dict[str, Any] | None = None,
    ) -> None:
        """Load the ONNX session (via the base class), then the tokenizer from ``model_dir``."""
        super()._load_onnx_model(
            model_dir=model_dir,
            model_file=model_file,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_id=device_id,
            extra_session_options=extra_session_options,
        )
        self.tokenizer, _ = load_tokenizer(model_dir=model_dir)
        assert self.tokenizer is not None

    def tokenize(self, pairs: list[tuple[str, str]], **_: Any) -> list[Encoding]:
        """Encode each (query, document) pair as one sequence-pair input.

        Extra keyword arguments are accepted for interface compatibility and ignored.
        """
        return self.tokenizer.encode_batch(pairs)  # type: ignore[union-attr]

    def _build_onnx_input(self, tokenized_input: list[Encoding]) -> dict[str, NumpyArray]:
        """Build the int64 feed dict for the ONNX session.

        ``input_ids`` is always fed. ``token_type_ids`` and ``attention_mask``
        are fed only when the loaded graph declares inputs with those names.
        NOTE(review): ``np.array`` over per-encoding id lists assumes every
        encoding in the batch has the same length (i.e. the tokenizer pads);
        presumably configured by ``load_tokenizer`` — TODO confirm.
        """
        input_names: set[str] = {node.name for node in self.model.get_inputs()}  # type: ignore[union-attr]
        inputs: dict[str, NumpyArray] = {
            "input_ids": np.array([enc.ids for enc in tokenized_input], dtype=np.int64),
        }
        if "token_type_ids" in input_names:
            inputs["token_type_ids"] = np.array(
                [enc.type_ids for enc in tokenized_input], dtype=np.int64
            )
        if "attention_mask" in input_names:
            inputs["attention_mask"] = np.array(
                [enc.attention_mask for enc in tokenized_input], dtype=np.int64
            )
        return inputs

    def onnx_embed(self, query: str, documents: list[str], **kwargs: Any) -> OnnxOutputContext:
        """Score every document in ``documents`` against the same ``query``."""
        pairs = [(query, doc) for doc in documents]
        return self.onnx_embed_pairs(pairs, **kwargs)

    def onnx_embed_pairs(self, pairs: list[tuple[str, str]], **kwargs: Any) -> OnnxOutputContext:
        """Run the model on a batch of pairs and return the raw scores.

        Returns:
            OnnxOutputContext whose ``model_output`` is column 0 of the first
            model output. This assumes the output is 2-D (batch, labels),
            as the ``[:, 0]`` indexing requires.
        """
        tokenized_input = self.tokenize(pairs, **kwargs)
        inputs = self._build_onnx_input(tokenized_input)
        onnx_input = self._preprocess_onnx_input(inputs, **kwargs)
        outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)  # type: ignore[union-attr]
        relevant_output = outputs[0]
        scores: NumpyArray = relevant_output[:, 0]
        return OnnxOutputContext(model_output=scores)

    def _rerank_documents(
        self, query: str, documents: Iterable[str], batch_size: int, **kwargs: Any
    ) -> Iterable[float]:
        """Lazily yield one score per document, processing ``batch_size`` documents at a time.

        The ONNX model is loaded on first use if it is not loaded yet.
        """
        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()
        for batch in iter_batch(documents, batch_size):
            yield from self._post_process_onnx_output(self.onnx_embed(query, batch, **kwargs))

    def _rerank_pairs(
        self,
        model_name: str,
        cache_dir: str,
        pairs: Iterable[tuple[str, str]],
        batch_size: int,
        parallel: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        local_files_only: bool = False,
        specific_model_path: str | None = None,
        extra_session_options: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Iterable[float]:
        """Lazily yield one score per (query, document) pair.

        Args:
            model_name, cache_dir, providers, local_files_only, specific_model_path:
                Forwarded to each worker when running in parallel.
            pairs: Pairs to score. A single bare ``(query, document)`` tuple of
                two strings is accepted as shorthand for one pair.
            batch_size: Number of pairs per model call.
            parallel: ``None`` runs in-process; ``0`` uses ``os.cpu_count()``
                workers; any other value is the worker count.
            cuda, device_ids: Forwarded to ``ParallelWorkerPool``.
            extra_session_options: Merged into the worker parameters.
            **kwargs: Forwarded to ``onnx_embed_pairs`` (in-process) or to the
                workers (parallel).

        Inputs known to be small (a single pair, or a list/tuple shorter than
        ``batch_size``) are always processed in-process.
        """
        is_small = False

        # Only wrap a tuple when it really is one (str, str) pair. A tuple *of*
        # pairs is an ordinary iterable of pairs and must not be nested again.
        if (
            isinstance(pairs, tuple)
            and len(pairs) == 2
            and all(isinstance(item, str) for item in pairs)
        ):
            pairs = [pairs]  # type: ignore[list-item]
            is_small = True

        if isinstance(pairs, (list, tuple)) and len(pairs) < batch_size:
            is_small = True

        if parallel is None or is_small:
            if not hasattr(self, "model") or self.model is None:
                self.load_onnx_model()
            for batch in iter_batch(pairs, batch_size):
                yield from self._post_process_onnx_output(self.onnx_embed_pairs(batch, **kwargs))
        else:
            if parallel == 0:
                parallel = os.cpu_count()

            # Prefer forkserver where the platform supports it.
            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
            params = {
                "model_name": model_name,
                "cache_dir": cache_dir,
                "providers": providers,
                "local_files_only": local_files_only,
                "specific_model_path": specific_model_path,
                **kwargs,
            }

            if extra_session_options is not None:
                params.update(extra_session_options)

            pool = ParallelWorkerPool(
                # os.cpu_count() may return None; fall back to one worker.
                num_workers=parallel or 1,
                worker=self._get_worker_class(),
                cuda=cuda,
                device_ids=device_ids,
                start_method=start_method,
            )
            for batch in pool.ordered_map(iter_batch(pairs, batch_size), **params):
                yield from self._post_process_onnx_output(batch)  # type: ignore

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, **kwargs: Any
    ) -> Iterable[float]:
        """Post-process the ONNX model output to convert it into a usable format.

        Args:
            output (OnnxOutputContext): The raw output from the ONNX model.
            **kwargs: Additional keyword arguments that may be needed by specific implementations.

        Returns:
            Iterable[float]: Post-processed output as an iterable of float values.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def _preprocess_onnx_input(
        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
    ) -> dict[str, NumpyArray]:
        """Hook for subclasses to adjust the ONNX feed dict; returns it unchanged by default."""
        return onnx_input

    def _token_count(
        self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **_: Any
    ) -> int:
        """Count tokens across ``pairs`` by summing each encoding's attention mask.

        The model (and with it the tokenizer) is loaded on first use if needed.
        """
        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()  # loads the tokenizer as well

        token_num = 0
        assert self.tokenizer is not None
        for batch in iter_batch(pairs, batch_size):
            for tokens in self.tokenizer.encode_batch(batch):
                token_num += sum(tokens.attention_mask)

        return token_num
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class TextRerankerWorker(EmbeddingWorker[float]):
    """Parallel worker that scores batches of (query, document) pairs.

    Concrete subclasses implement ``init_embedding`` to build the
    cross-encoder model the worker uses.
    """

    def __init__(
        self,
        model_name: str,
        cache_dir: str,
        **kwargs: Any,
    ):
        # Type-only declaration for ``self.model``; it is presumably populated
        # by the base initialiser through ``init_embedding`` — TODO confirm.
        self.model: OnnxCrossEncoderModel
        super().__init__(model_name, cache_dir, **kwargs)

    def init_embedding(
        self,
        model_name: str,
        cache_dir: str,
        **kwargs: Any,
    ) -> OnnxCrossEncoderModel:
        """Build the model for this worker.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError()

    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
        """Yield ``(index, OnnxOutputContext)`` for every ``(index, batch)`` received.

        Each batch is scored with ``onnx_embed_pairs``, in the order the items arrive.
        """
        yield from (
            (position, self.model.onnx_embed_pairs(pair_batch))
            for position, pair_batch in items
        )
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
from typing import Any, Iterable, Sequence, Type
|
|
2
|
+
from dataclasses import asdict
|
|
3
|
+
|
|
4
|
+
from fastembed.common import OnnxProvider
|
|
5
|
+
from fastembed.common.types import Device
|
|
6
|
+
from fastembed.rerank.cross_encoder.onnx_text_cross_encoder import OnnxTextCrossEncoder
|
|
7
|
+
from fastembed.rerank.cross_encoder.custom_text_cross_encoder import CustomTextCrossEncoder
|
|
8
|
+
|
|
9
|
+
from fastembed.rerank.cross_encoder.text_cross_encoder_base import TextCrossEncoderBase
|
|
10
|
+
from fastembed.common.model_description import (
|
|
11
|
+
ModelSource,
|
|
12
|
+
BaseModelDescription,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TextCrossEncoder(TextCrossEncoderBase):
    """User-facing cross-encoder that delegates to a registered backend.

    On construction, ``CROSS_ENCODER_REGISTRY`` is searched in order. The
    first backend whose supported models include ``model_name`` (compared
    case-insensitively) is instantiated and used for all calls.
    """

    # Backends consulted in order; the first one listing the model wins.
    CROSS_ENCODER_REGISTRY: list[Type[TextCrossEncoderBase]] = [
        OnnxTextCrossEncoder,
        CustomTextCrossEncoder,
    ]

    @classmethod
    def list_supported_models(cls) -> list[dict[str, Any]]:
        """Lists the supported models.

        Returns:
            list[dict[str, Any]]: A list of dictionaries containing the model information.

            Example:
                ```
                [
                    {
                        "model": "Xenova/ms-marco-MiniLM-L-6-v2",
                        "size_in_GB": 0.08,
                        "sources": {
                            "hf": "Xenova/ms-marco-MiniLM-L-6-v2",
                        },
                        "model_file": "onnx/model.onnx",
                        "description": "MiniLM-L-6-v2 model optimized for re-ranking tasks.",
                        "license": "apache-2.0",
                    }
                ]
                ```
        """
        return [asdict(model) for model in cls._list_supported_models()]

    @classmethod
    def _list_supported_models(cls) -> list[BaseModelDescription]:
        """Collect model descriptions from every registered backend, in registry order."""
        result: list[BaseModelDescription] = []
        for encoder in cls.CROSS_ENCODER_REGISTRY:
            result.extend(encoder._list_supported_models())
        return result

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        lazy_load: bool = False,
        **kwargs: Any,
    ):
        """Instantiate the first registered backend that supports ``model_name``.

        Raises:
            ValueError: If no registered backend lists ``model_name``.
        """
        super().__init__(model_name, cache_dir, threads, **kwargs)

        self.model: TextCrossEncoderBase
        model_name_lower = model_name.lower()
        for encoder_type in self.CROSS_ENCODER_REGISTRY:
            supported_models = encoder_type._list_supported_models()
            if any(model_name_lower == model.model.lower() for model in supported_models):
                self.model = encoder_type(
                    model_name=model_name,
                    cache_dir=cache_dir,
                    threads=threads,
                    providers=providers,
                    cuda=cuda,
                    device_ids=device_ids,
                    lazy_load=lazy_load,
                    **kwargs,
                )
                return

        raise ValueError(
            f"Model {model_name} is not supported in TextCrossEncoder. "
            "Please check the supported models using `TextCrossEncoder.list_supported_models()`"
        )

    def rerank(
        self, query: str, documents: Iterable[str], batch_size: int = 64, **kwargs: Any
    ) -> Iterable[float]:
        """Rerank a list of documents based on a query.

        Args:
            query: Query to rerank the documents against
            documents: Iterator of documents to rerank
            batch_size: Batch size for reranking

        Returns:
            Iterable of scores for each document
        """
        yield from self.model.rerank(query, documents, batch_size=batch_size, **kwargs)

    def rerank_pairs(
        self,
        pairs: Iterable[tuple[str, str]],
        batch_size: int = 64,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[float]:
        """
        Rerank a list of query-document pairs.

        Args:
            pairs (Iterable[tuple[str, str]]): An iterable of tuples, where each tuple contains a query and a document
                to be scored together.
            batch_size (int, optional): The number of query-document pairs to process in a single batch. Defaults to 64.
            parallel (Optional[int], optional): The number of parallel processes to use for reranking.
                If None, parallelization is disabled. Defaults to None.
            **kwargs (Any): Additional arguments to pass to the underlying reranking model.

        Returns:
            Iterable[float]: An iterable of scores corresponding to each query-document pair in the input.
            Higher scores indicate a stronger match between the query and the document.

        Example:
            >>> encoder = TextCrossEncoder("Xenova/ms-marco-MiniLM-L-6-v2")
            >>> pairs = [("What is AI?", "Artificial intelligence is ..."), ("What is ML?", "Machine learning is ...")]
            >>> scores = list(encoder.rerank_pairs(pairs))
            >>> print(list(map(lambda x: round(x, 2), scores)))
            [-1.24, -10.6]
        """
        yield from self.model.rerank_pairs(
            pairs, batch_size=batch_size, parallel=parallel, **kwargs
        )

    @classmethod
    def add_custom_model(
        cls,
        model: str,
        sources: ModelSource,
        model_file: str = "onnx/model.onnx",
        description: str = "",
        license: str = "",
        size_in_gb: float = 0.0,
        additional_files: list[str] | None = None,
    ) -> None:
        """Register a custom cross-encoder model with ``CustomTextCrossEncoder``.

        Names are compared case-insensitively, matching how ``__init__``
        resolves them. Otherwise a differently-cased duplicate could be
        registered and then shadowed by the existing model.

        Raises:
            ValueError: If a model with the same name (ignoring case) is already registered.
        """
        model_lower = model.lower()
        for registered_model in cls._list_supported_models():
            if model_lower == registered_model.model.lower():
                raise ValueError(
                    f"Model {model} is already registered in CrossEncoderModel, if you still want to add this model, "
                    "please use another model name"
                )

        CustomTextCrossEncoder.add_model(
            BaseModelDescription(
                model=model,
                sources=sources,
                model_file=model_file,
                description=description,
                license=license,
                size_in_GB=size_in_gb,
                additional_files=additional_files or [],
            )
        )

    def token_count(
        self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        """Returns the number of tokens in the pairs.

        Args:
            pairs: Iterable of tuples, where each tuple contains a query and a document to be tokenized
            batch_size: Batch size for tokenizing

        Returns:
            token count: overall number of tokens in the pairs
        """
        return self.model.token_count(pairs, batch_size=batch_size, **kwargs)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from typing import Any, Iterable
|
|
2
|
+
|
|
3
|
+
from fastembed.common.model_description import BaseModelDescription
|
|
4
|
+
from fastembed.common.model_management import ModelManagement
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TextCrossEncoderBase(ModelManagement[BaseModelDescription]):
    """Abstract interface shared by all text cross-encoders.

    Stores the common constructor settings. Subclasses implement
    ``rerank``, ``rerank_pairs`` and ``token_count``.
    """

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        **kwargs: Any,
    ):
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.threads = threads
        # Consumed here, so it is not left in this call's kwargs.
        self._local_files_only = kwargs.pop("local_files_only", False)

    def rerank(
        self,
        query: str,
        documents: Iterable[str],
        batch_size: int = 64,
        **kwargs: Any,
    ) -> Iterable[float]:
        """Rerank a list of documents given a query.

        Args:
            query (str): The query to rerank the documents against.
            documents (Iterable[str]): The texts to rerank.
            batch_size (int): The batch size to use for reranking.
            **kwargs: Additional keyword arguments to pass to the rerank method.

        Yields:
            Iterable[float]: One score per document, in input order.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError("This method should be overridden by subclasses")

    def rerank_pairs(
        self,
        pairs: Iterable[tuple[str, str]],
        batch_size: int = 64,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[float]:
        """Rerank query-document pairs.

        Args:
            pairs (Iterable[tuple[str, str]]): Query-document pairs to rerank.
            batch_size (int): The batch size to use for reranking.
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.
            **kwargs: Additional keyword arguments to pass to the rerank method.

        Yields:
            Iterable[float]: Scores for each individual pair.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError("This method should be overridden by subclasses")

    def token_count(
        self, pairs: Iterable[tuple[str, str]], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        """Returns the number of tokens in the pairs.

        Args:
            pairs: Query-document pairs to tokenize.
            batch_size: Number of pairs tokenized per call.
            **kwargs: Implementation-specific options.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError("This method should be overridden by subclasses")
|