fastembed-bio 0.1.0 (fastembed_bio-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/sparse/bm42.py
ADDED
@@ -0,0 +1,369 @@
import math
import string
from pathlib import Path
from typing import Any, Iterable, Sequence, Type

import mmh3
import numpy as np
from py_rust_stemmers import SnowballStemmer

from fastembed.common import OnnxProvider
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import Device
from fastembed.common.utils import define_cache_dir
from fastembed.sparse.sparse_embedding_base import (
    SparseEmbedding,
    SparseTextEmbeddingBase,
)
from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker
from fastembed.common.model_description import SparseModelDescription, ModelSource

supported_bm42_models: list[SparseModelDescription] = [
    SparseModelDescription(
        model="Qdrant/bm42-all-minilm-l6-v2-attentions",
        vocab_size=30522,
        description="Light sparse embedding model, which assigns an importance score to each token in the text",
        license="apache-2.0",
        size_in_GB=0.09,
        sources=ModelSource(hf="Qdrant/all_miniLM_L6_v2_with_attentions"),
        model_file="model.onnx",
        additional_files=["stopwords.txt"],
        requires_idf=True,
    ),
]


_MODEL_TO_LANGUAGE = {
    "Qdrant/bm42-all-minilm-l6-v2-attentions": "english",
}
MODEL_TO_LANGUAGE = {
    model_name.lower(): language for model_name, language in _MODEL_TO_LANGUAGE.items()
}


def get_language_by_model_name(model_name: str) -> str:
    return MODEL_TO_LANGUAGE[model_name.lower()]

class Bm42(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]):
    """
    Bm42 is an extension of BM25 which aims to better evaluate the importance of tokens
    in documents by extracting attention weights from a transformer model.

    Traditional BM25 uses token counts within a document to evaluate token importance,
    but this approach doesn't work well with short documents or chunks of text, where
    almost all tokens are unique.

    BM42 addresses this issue by replacing the token count with the attention weights from the transformer model.
    This allows sparse embeddings to work well with short documents, handle rare tokens and leverage traditional NLP
    techniques like stemming and stopwords.

    WARNING: This model is expected to be used with `modifier="idf"` in the sparse vector index of Qdrant.
    """

    ONNX_OUTPUT_NAMES = ["attention_6"]

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        alpha: float = 0.5,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        lazy_load: bool = False,
        device_id: int | None = None,
        specific_model_path: str | None = None,
        **kwargs: Any,
    ):
        """
        Args:
            model_name (str): The name of the model to use.
            cache_dir (str, optional): The path to the cache directory.
                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                Defaults to `fastembed_cache` in the system's temp directory.
            threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
            providers (Optional[Sequence[OnnxProvider]], optional): The providers to use for onnxruntime.
            alpha (float, optional): Parameter that balances the importance of the token weight in the document
                against the importance of the token frequency in the corpus. Defaults to 0.5, based on empirical testing.
                It is recommended to only change this parameter based on training data for a specific dataset.
            cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
                Defaults to Device.AUTO.
            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
                workers. Should be used with `cuda` set to `True`, `Device.AUTO` or `Device.CUDA`; mutually exclusive
                with `providers`. Defaults to None.
            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
                Should be set to True when using multi-GPU and parallel encoding. Defaults to False.
            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
            specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be loaded from somewhere else.

        Raises:
            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
        """

        super().__init__(model_name, cache_dir, threads, **kwargs)
        self.providers = providers
        self.lazy_load = lazy_load
        self._extra_session_options = self._select_exposed_session_options(kwargs)

        # List of device ids that can be used for data parallel processing in workers
        self.device_ids = device_ids
        self.cuda = cuda

        # This device_id will be used if we need to load the model in the current process
        self.device_id: int | None = None
        if device_id is not None:
            self.device_id = device_id
        elif self.device_ids is not None:
            self.device_id = self.device_ids[0]

        self.model_description = self._get_model_description(model_name)
        self.cache_dir = str(define_cache_dir(cache_dir))

        self._specific_model_path = specific_model_path
        self._model_dir = self.download_model(
            self.model_description,
            self.cache_dir,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
        )

        self.invert_vocab: dict[int, str] = {}

        self.special_tokens: set[str] = set()
        self.special_tokens_ids: set[int] = set()
        self.punctuation = set(string.punctuation)
        self.stopwords = set(self._load_stopwords(self._model_dir))
        self.stemmer = SnowballStemmer(get_language_by_model_name(self.model_name))
        self.alpha = alpha

        if not self.lazy_load:
            self.load_onnx_model()

    def load_onnx_model(self) -> None:
        self._load_onnx_model(
            model_dir=self._model_dir,
            model_file=self.model_description.model_file,
            threads=self.threads,
            providers=self.providers,
            cuda=self.cuda,
            device_id=self.device_id,
            extra_session_options=self._extra_session_options,
        )

        for token, idx in self.tokenizer.get_vocab().items():  # type: ignore[union-attr]
            self.invert_vocab[idx] = token
        self.special_tokens = set(self.special_token_to_id.keys())
        self.special_tokens_ids = set(self.special_token_to_id.values())
        self.stopwords = set(self._load_stopwords(self._model_dir))

    def _filter_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
        result: list[tuple[str, Any]] = []
        for token, value in tokens:
            if token in self.stopwords or token in self.punctuation:
                continue
            result.append((token, value))
        return result

    def _stem_pair_tokens(self, tokens: list[tuple[str, Any]]) -> list[tuple[str, Any]]:
        result: list[tuple[str, Any]] = []
        for token, value in tokens:
            processed_token = self.stemmer.stem_word(token)
            result.append((processed_token, value))
        return result

    @classmethod
    def _aggregate_weights(
        cls, tokens: list[tuple[str, list[int]]], weights: list[float]
    ) -> list[tuple[str, float]]:
        result: list[tuple[str, float]] = []
        for token, idxs in tokens:
            sum_weight = sum(weights[idx] for idx in idxs)
            result.append((token, sum_weight))
        return result

    def _reconstruct_bpe(
        self, bpe_tokens: Iterable[tuple[int, str]]
    ) -> list[tuple[str, list[int]]]:
        result: list[tuple[str, list[int]]] = []
        acc: str = ""
        acc_idx: list[int] = []

        continuing_subword_prefix = self.tokenizer.model.continuing_subword_prefix  # type: ignore[union-attr]
        continuing_subword_prefix_len = len(continuing_subword_prefix)

        for idx, token in bpe_tokens:
            if token in self.special_tokens:
                continue

            if token.startswith(continuing_subword_prefix):
                acc += token[continuing_subword_prefix_len:]
                acc_idx.append(idx)
            else:
                if acc:
                    result.append((acc, acc_idx))
                    acc_idx = []
                acc = token
                acc_idx.append(idx)

        if acc:
            result.append((acc, acc_idx))

        return result

    def _rescore_vector(self, vector: dict[str, float]) -> dict[int, float]:
        """
        Converts tokens to hashed ids and rescales each weight as log(1 + w) ** alpha,
        dampening large attention values so that scoring depends less on the absolute
        magnitudes assigned by the model.
        """

        new_vector: dict[int, float] = {}

        for token, value in vector.items():
            token_id = abs(mmh3.hash(token))
            # Dampen the raw attention weight: log(1 + w) ** alpha
            new_vector[token_id] = math.log(1.0 + value) ** self.alpha

        return new_vector

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, **kwargs: Any
    ) -> Iterable[SparseEmbedding]:
        if output.input_ids is None:
            raise ValueError("input_ids must be provided for document post-processing")

        token_ids_batch = output.input_ids.astype(int)

        # model_output shape: (batch_size, num_heads, num_tokens, num_tokens).
        # Take the attention row of the token at position 0 ([CLS]), average it
        # over heads and zero out padding positions with the attention mask.
        pooled_attention = np.mean(output.model_output[:, :, 0], axis=1) * output.attention_mask

        for document_token_ids, attention_value in zip(token_ids_batch, pooled_attention):
            document_tokens_with_ids = (
                (idx, self.invert_vocab[token_id])
                for idx, token_id in enumerate(document_token_ids)
            )

            reconstructed = self._reconstruct_bpe(document_tokens_with_ids)

            filtered = self._filter_pair_tokens(reconstructed)

            stemmed = self._stem_pair_tokens(filtered)

            weighted = self._aggregate_weights(stemmed, attention_value)

            max_token_weight: dict[str, float] = {}

            for token, weight in weighted:
                max_token_weight[token] = max(max_token_weight.get(token, 0), weight)

            rescored = self._rescore_vector(max_token_weight)

            yield SparseEmbedding.from_dict(rescored)

    @classmethod
    def _list_supported_models(cls) -> list[SparseModelDescription]:
        """Lists the supported models.

        Returns:
            list[SparseModelDescription]: A list of SparseModelDescription objects containing the model information.
        """
        return supported_bm42_models

    @classmethod
    def _load_stopwords(cls, model_dir: Path) -> list[str]:
        stopwords_path = model_dir / "stopwords.txt"
        if not stopwords_path.exists():
            return []

        with open(stopwords_path, "r") as f:
            return f.read().splitlines()

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[SparseEmbedding]:
        """
        Encode a list of documents into a list of sparse embeddings.

        Args:
            documents: Iterator of documents or a single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing; use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per document
        """
        yield from self._embed_documents(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            documents=documents,
            batch_size=batch_size,
            parallel=parallel,
            providers=self.providers,
            cuda=self.cuda,
            device_ids=self.device_ids,
            alpha=self.alpha,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
            extra_session_options=self._extra_session_options,
        )

    @classmethod
    def _query_rehash(cls, tokens: Iterable[str]) -> dict[int, float]:
        result: dict[int, float] = {}
        for token in tokens:
            token_id = abs(mmh3.hash(token))
            result[token_id] = 1.0
        return result

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[SparseEmbedding]:
        """
        To emulate BM25 behaviour, we don't need smart weights in the query; it is
        enough to hash the tokens and assign each a weight of 1.0. This is also
        faster, as we don't need to run the model for the query.
        """
        if isinstance(query, str):
            query = [query]

        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()

        for text in query:
            encoded = self.tokenizer.encode(text)  # type: ignore[union-attr]
            document_tokens_with_ids = enumerate(encoded.tokens)
            reconstructed = self._reconstruct_bpe(document_tokens_with_ids)
            filtered = self._filter_pair_tokens(reconstructed)
            stemmed = self._stem_pair_tokens(filtered)

            yield SparseEmbedding.from_dict(self._query_rehash(token for token, _ in stemmed))

    @classmethod
    def _get_worker_class(cls) -> Type[TextEmbeddingWorker[SparseEmbedding]]:
        return Bm42TextEmbeddingWorker

    def token_count(
        self, texts: str | Iterable[str], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()  # loads the tokenizer as well
        return self._token_count(texts, batch_size=batch_size, **kwargs)


class Bm42TextEmbeddingWorker(TextEmbeddingWorker[SparseEmbedding]):
    def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> Bm42:
        return Bm42(
            model_name=model_name,
            cache_dir=cache_dir,
            **kwargs,
        )
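
Usage note: a minimal sketch of how this class is typically driven, assuming fastembed-bio preserves upstream fastembed's Bm42 API (model weights are downloaded on first use):

from fastembed.sparse.bm42 import Bm42

model = Bm42(model_name="Qdrant/bm42-all-minilm-l6-v2-attentions")

documents = [
    "BM42 replaces token counts with transformer attention weights.",
    "Sparse vectors work well for keyword-style retrieval.",
]

# Document side: runs the ONNX model, yields one SparseEmbedding per document
for emb in model.embed(documents):
    print(emb.indices[:5], emb.values[:5])

# Query side: no model inference, just tokenize, filter, stem and hash
query_emb = next(model.query_embed("attention weights"))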
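
The merging step in `_reconstruct_bpe` glues WordPiece continuation pieces (the `##` prefix for this model's BERT-style tokenizer) back into whole words while remembering which token positions each word came from, so their attention weights can later be summed. A standalone sketch of the same logic, decoupled from the tokenizer:

def reconstruct_bpe(bpe_tokens, prefix="##", special_tokens=frozenset({"[CLS]", "[SEP]"})):
    result, acc, acc_idx = [], "", []
    for idx, token in bpe_tokens:
        if token in special_tokens:
            continue
        if token.startswith(prefix):
            # Continuation piece: extend the current word, remember its position
            acc += token[len(prefix):]
            acc_idx.append(idx)
        else:
            # New word starts: flush the accumulated one
            if acc:
                result.append((acc, acc_idx))
                acc_idx = []
            acc = token
            acc_idx.append(idx)
    if acc:
        result.append((acc, acc_idx))
    return result

tokens = enumerate(["[CLS]", "em", "##bed", "##ding", "works", "[SEP]"])
print(reconstruct_bpe(tokens))  # [('embedding', [1, 2, 3]), ('works', [4])]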
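
The document-side scoring in `_rescore_vector` is simple enough to illustrate in isolation: each stemmed token is mapped to a stable id with MurmurHash3, and its aggregated attention weight is dampened as log(1 + w) ** alpha. A self-contained sketch (the tokens and weights below are made up for illustration):

import math

import mmh3

alpha = 0.5  # same default as Bm42

# Hypothetical stemmed tokens with aggregated attention weights
token_weights = {"attent": 0.42, "weight": 0.17, "transform": 0.08}

sparse_vector = {
    abs(mmh3.hash(token)): math.log(1.0 + weight) ** alpha
    for token, weight in token_weights.items()
}
# Larger weights still score higher, but differences are compressed,
# so ranking depends less on the model's absolute magnitudes.
print(sparse_vector)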
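
Query and document vectors land in the same index space because both sides hash the same stemmed surface forms; since `_query_rehash` assigns every query token a weight of 1.0, the dot product against a document vector reduces to summing the document-side scores of the matched tokens (IDF weighting is then applied by Qdrant via `modifier="idf"`). A toy check with hand-stemmed tokens and hypothetical document weights:

import mmh3

def query_rehash(tokens):
    # Mirrors Bm42._query_rehash: every query token gets weight 1.0
    return {abs(mmh3.hash(t)): 1.0 for t in tokens}

doc_vector = {abs(mmh3.hash("attent")): 0.59, abs(mmh3.hash("weight")): 0.40}
query_vector = query_rehash(["attent", "score"])

# Dot product over shared ids: only "attent" matches here
score = sum(doc_vector.get(idx, 0.0) * w for idx, w in query_vector.items())
print(score)  # 0.59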