fastembed-bio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/__init__.py
ADDED
@@ -0,0 +1,24 @@
+import importlib.metadata
+
+from fastembed.bio import ProteinEmbedding
+from fastembed.image import ImageEmbedding
+from fastembed.late_interaction import LateInteractionTextEmbedding
+from fastembed.late_interaction_multimodal import LateInteractionMultimodalEmbedding
+from fastembed.sparse import SparseEmbedding, SparseTextEmbedding
+from fastembed.text import TextEmbedding
+
+try:
+    version = importlib.metadata.version("fastembed")
+except importlib.metadata.PackageNotFoundError as _:
+    version = importlib.metadata.version("fastembed-gpu")
+
+__version__ = version
+__all__ = [
+    "TextEmbedding",
+    "SparseTextEmbedding",
+    "SparseEmbedding",
+    "ImageEmbedding",
+    "LateInteractionTextEmbedding",
+    "LateInteractionMultimodalEmbedding",
+    "ProteinEmbedding",
+]
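The root `__init__` re-exports `ProteinEmbedding` alongside the upstream fastembed classes. A minimal usage sketch, assuming `fastembed-bio` is installed and the default ESM-2 checkpoint can be fetched from Hugging Face (it mirrors the docstring example in `protein_embedding.py` below and is not part of the wheel itself):

```python
# Hypothetical quick-start: embed one protein sequence via the root-level export.
# Downloads the default ESM-2 ONNX model on first use.
from fastembed import ProteinEmbedding

model = ProteinEmbedding("facebook/esm2_t12_35M_UR50D")
vectors = list(model.embed(["MKTVRQERLKS"]))
print(vectors[0].shape)  # (480,) per the model registry in protein_embedding.py
```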
fastembed/bio/protein_embedding.py
ADDED
@@ -0,0 +1,456 @@
+import json
+from dataclasses import asdict
+from pathlib import Path
+from typing import Any, Iterable, Sequence, Type
+
+import numpy as np
+from tokenizers import Tokenizer, pre_tokenizers, processors
+from tokenizers.models import WordLevel
+
+from fastembed.common.model_description import DenseModelDescription, ModelSource
+from fastembed.common.model_management import ModelManagement
+from fastembed.common.onnx_model import OnnxModel, OnnxOutputContext, EmbeddingWorker
+from fastembed.common.types import NumpyArray, OnnxProvider, Device
+from fastembed.common.utils import define_cache_dir, iter_batch, normalize
+
+
+supported_protein_models: list[DenseModelDescription] = [
+    DenseModelDescription(
+        model="facebook/esm2_t12_35M_UR50D",
+        dim=480,
+        description="Protein embeddings, ESM-2 35M parameters, 480 dimensions, 1024 max sequence length",
+        license="mit",
+        size_in_GB=0.13,
+        sources=ModelSource(hf="nleroy917/esm2_t12_35M_UR50D-onnx"),
+        model_file="model.onnx",
+        additional_files=["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"],
+    ),
+]
+
+
+def load_protein_tokenizer(model_dir: Path, max_length: int = 1024) -> Tokenizer:
+    """Load a protein tokenizer from model directory using HuggingFace tokenizers.
+
+    Attempts to load in order:
+    1. tokenizer.json (standard HuggingFace fast tokenizer format)
+    2. Build from vocab.txt (fallback for models without tokenizer.json)
+
+    Args:
+        model_dir: Path to model directory containing tokenizer files
+        max_length: Maximum sequence length (default, can be overridden by config)
+
+    Returns:
+        Configured Tokenizer instance
+    """
+    tokenizer_json_path = model_dir / "tokenizer.json"
+    tokenizer_config_path = model_dir / "tokenizer_config.json"
+    vocab_path = model_dir / "vocab.txt"
+
+    # Try to load tokenizer.json directly (preferred)
+    if tokenizer_json_path.exists():
+        tokenizer = Tokenizer.from_file(str(tokenizer_json_path))
+        # Read max_length from config if available
+        if tokenizer_config_path.exists():
+            with open(tokenizer_config_path) as f:
+                config = json.load(f)
+            config_max_length = config.get("model_max_length", max_length)
+            # Cap at reasonable value (transformers defaults can be huge)
+            if config_max_length <= max_length:
+                max_length = config_max_length
+        tokenizer.enable_truncation(max_length=max_length)
+        return tokenizer
+
+    # Fall back to building from vocab.txt
+    if not vocab_path.exists():
+        raise ValueError(
+            f"Could not find tokenizer.json or vocab.txt in {model_dir}"
+        )
+
+    # Read max_length from config if available
+    if tokenizer_config_path.exists():
+        with open(tokenizer_config_path) as f:
+            config = json.load(f)
+        max_length = config.get("model_max_length", max_length)
+
+    vocab: dict[str, int] = {}
+    with open(vocab_path) as f:
+        for idx, line in enumerate(f):
+            token = line.strip()
+            vocab[token] = idx
+
+    unk_token = "<unk>"
+    cls_token = "<cls>"
+    eos_token = "<eos>"
+    pad_token = "<pad>"
+
+    tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token=unk_token))
+
+    tokenizer.pre_tokenizer = pre_tokenizers.Split(
+        pattern="", behavior="isolated", invert=False
+    )
+
+    cls_token_id = vocab.get(cls_token, 0)
+    eos_token_id = vocab.get(eos_token, 2)
+
+    tokenizer.post_processor = processors.TemplateProcessing(
+        single=f"{cls_token}:0 $A:0 {eos_token}:0",
+        special_tokens=[
+            (cls_token, cls_token_id),
+            (eos_token, eos_token_id),
+        ],
+    )
+
+    pad_token_id = vocab.get(pad_token, 1)
+    tokenizer.enable_padding(pad_id=pad_token_id, pad_token=pad_token)
+    tokenizer.enable_truncation(max_length=max_length)
+
+    return tokenizer
+
+
+class ProteinEmbeddingBase(ModelManagement[DenseModelDescription]):
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: str | None = None,
+        threads: int | None = None,
+        **kwargs: Any,
+    ):
+        self.model_name = model_name
+        self.cache_dir = cache_dir
+        self.threads = threads
+        self._local_files_only = kwargs.pop("local_files_only", False)
+        self._embedding_size: int | None = None
+
+    def embed(
+        self,
+        sequences: str | Iterable[str],
+        batch_size: int = 32,
+        parallel: int | None = None,
+        **kwargs: Any,
+    ) -> Iterable[NumpyArray]:
+        """
+        Embed protein sequences.
+
+        Args:
+            sequences: Single protein sequence or iterable of sequences
+            batch_size: Batch size for encoding
+            parallel: Number of parallel workers (None for single-threaded)
+
+        Yields:
+            Embeddings as numpy arrays
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def get_embedding_size(cls, model_name: str) -> int:
+        """
+        Returns embedding size of the passed model.
+
+        Args:
+            model_name: Name of the model
+        """
+        descriptions = cls._list_supported_models()
+        for description in descriptions:
+            if description.model.lower() == model_name.lower():
+                if description.dim is not None:
+                    return description.dim
+        raise ValueError(f"Model {model_name} not found")
+
+    @property
+    def embedding_size(self) -> int:
+        """
+        Returns embedding size for the current model.
+        """
+        if self._embedding_size is None:
+            self._embedding_size = self.get_embedding_size(self.model_name)
+        return self._embedding_size
+
+
+class OnnxProteinModel(OnnxModel[NumpyArray]):
+    """
+    ONNX model handler for protein embeddings.
+    """
+
+    ONNX_OUTPUT_NAMES: list[str] | None = None
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.tokenizer: Tokenizer | None = None
+
+    def _load_onnx_model(
+        self,
+        model_dir: Path,
+        model_file: str,
+        threads: int | None,
+        providers: Sequence[OnnxProvider] | None = None,
+        cuda: bool | Device = Device.AUTO,
+        device_id: int | None = None,
+        extra_session_options: dict[str, Any] | None = None,
+    ) -> None:
+        super()._load_onnx_model(
+            model_dir=model_dir,
+            model_file=model_file,
+            threads=threads,
+            providers=providers,
+            cuda=cuda,
+            device_id=device_id,
+            extra_session_options=extra_session_options,
+        )
+        self.tokenizer = load_protein_tokenizer(model_dir)
+
+    def onnx_embed(self, sequences: list[str], **kwargs: Any) -> OnnxOutputContext:
+        """
+        Run ONNX inference on protein sequences.
+
+        Args:
+            sequences: List of protein sequences
+        Returns:
+            OnnxOutputContext containing model output and inputs
+        """
+        assert self.tokenizer is not None
+
+        sequences = [seq.upper() for seq in sequences]
+        encoded = self.tokenizer.encode_batch(sequences)
+        input_ids = np.array([e.ids for e in encoded], dtype=np.int64)
+        attention_mask = np.array([e.attention_mask for e in encoded], dtype=np.int64)
+
+        input_names = {node.name for node in self.model.get_inputs()}  # type: ignore[union-attr]
+        onnx_input: dict[str, NumpyArray] = {
+            "input_ids": input_ids,
+        }
+        if "attention_mask" in input_names:
+            onnx_input["attention_mask"] = attention_mask
+
+        model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)  # type: ignore[union-attr]
+
+        return OnnxOutputContext(
+            model_output=model_output[0],
+            attention_mask=attention_mask,
+            input_ids=input_ids,
+        )
+
+    def _post_process_onnx_output(
+        self, output: OnnxOutputContext, **kwargs: Any
+    ) -> Iterable[NumpyArray]:
+        """Convert ONNX output to embeddings with mean pooling."""
+        embeddings = output.model_output
+        attention_mask = output.attention_mask
+
+        if attention_mask is None:
+            raise ValueError("attention_mask is required for mean pooling")
+
+        mask_expanded = np.expand_dims(attention_mask, axis=-1)
+        sum_embeddings = np.sum(embeddings * mask_expanded, axis=1)
+        sum_mask = np.sum(mask_expanded, axis=1)
+        sum_mask = np.clip(sum_mask, a_min=1e-9, a_max=None)
+        mean_embeddings = sum_embeddings / sum_mask
+
+        return normalize(mean_embeddings)
+
+
+class OnnxProteinEmbedding(ProteinEmbeddingBase, OnnxProteinModel):
+    """
+    ONNX-based protein embedding implementation.
+    """
+
+    @classmethod
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
+        return supported_protein_models
+
+    def __init__(
+        self,
+        model_name: str = "facebook/esm2_t12_35M_UR50D",
+        cache_dir: str | None = None,
+        threads: int | None = None,
+        providers: Sequence[OnnxProvider] | None = None,
+        cuda: bool | Device = Device.AUTO,
+        device_ids: list[int] | None = None,
+        lazy_load: bool = False,
+        device_id: int | None = None,
+        specific_model_path: str | None = None,
+        **kwargs: Any,
+    ):
+        super().__init__(model_name, cache_dir, threads, **kwargs)
+        self.providers = providers
+        self.lazy_load = lazy_load
+        self._extra_session_options = self._select_exposed_session_options(kwargs)
+        self.device_ids = device_ids
+        self.cuda = cuda
+
+        self.device_id: int | None = None
+        if device_id is not None:
+            self.device_id = device_id
+        elif self.device_ids is not None:
+            self.device_id = self.device_ids[0]
+
+        self.model_description = self._get_model_description(model_name)
+        self.cache_dir = str(define_cache_dir(cache_dir))
+        self._specific_model_path = specific_model_path
+        self._model_dir = self.download_model(
+            self.model_description,
+            self.cache_dir,
+            local_files_only=self._local_files_only,
+            specific_model_path=self._specific_model_path,
+        )
+
+        if not self.lazy_load:
+            self.load_onnx_model()
+
+    def load_onnx_model(self) -> None:
+        self._load_onnx_model(
+            model_dir=self._model_dir,
+            model_file=self.model_description.model_file,
+            threads=self.threads,
+            providers=self.providers,
+            cuda=self.cuda,
+            device_id=self.device_id,
+            extra_session_options=self._extra_session_options,
+        )
+
+    def embed(
+        self,
+        sequences: str | Iterable[str],
+        batch_size: int = 32,
+        parallel: int | None = None,
+        **kwargs: Any,
+    ) -> Iterable[NumpyArray]:
+        """
+        Embed protein sequences.
+
+        Args:
+            sequences: Single protein sequence or iterable of sequences (amino acid strings)
+            batch_size: Batch size for encoding
+            parallel: Number of parallel workers (not yet supported)
+
+        Yields:
+            Embeddings as numpy arrays, one per sequence
+        """
+        if isinstance(sequences, str):
+            sequences = [sequences]
+
+        if not hasattr(self, "model") or self.model is None:
+            self.load_onnx_model()
+
+        for batch in iter_batch(sequences, batch_size):
+            yield from self._post_process_onnx_output(self.onnx_embed(batch, **kwargs), **kwargs)
+
+    @classmethod
+    def _get_worker_class(cls) -> Type["ProteinEmbeddingWorker"]:
+        return ProteinEmbeddingWorker
+
+
+class ProteinEmbeddingWorker(EmbeddingWorker[NumpyArray]):
+    def init_embedding(
+        self,
+        model_name: str,
+        cache_dir: str,
+        **kwargs: Any,
+    ) -> OnnxProteinEmbedding:
+        return OnnxProteinEmbedding(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=1,
+            **kwargs,
+        )
+
+    def process(
+        self, items: Iterable[tuple[int, Any]]
+    ) -> Iterable[tuple[int, OnnxOutputContext]]:
+        for idx, batch in items:
+            onnx_output = self.model.onnx_embed(batch)
+            yield idx, onnx_output
+
+
+class ProteinEmbedding(ProteinEmbeddingBase):
+    """
+    Protein sequence embedding using ESM-2 and similar models.
+
+    Example:
+        >>> from fastembed.bio import ProteinEmbedding
+        >>> model = ProteinEmbedding("facebook/esm2_t12_35M_UR50D")
+        >>> embeddings = list(model.embed(["MKTVRQERLKS", "GKGDPKKPRGKM"]))
+        >>> print(embeddings[0].shape)
+        (480,)
+    """
+
+    EMBEDDINGS_REGISTRY: list[Type[ProteinEmbeddingBase]] = [OnnxProteinEmbedding]
+
+    @classmethod
+    def list_supported_models(cls) -> list[dict[str, Any]]:
+        """Lists the supported models.
+
+        Returns:
+            list[dict[str, Any]]: A list of dictionaries containing the model information.
+        """
+        return [asdict(model) for model in cls._list_supported_models()]
+
+    @classmethod
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
+        result: list[DenseModelDescription] = []
+        for embedding in cls.EMBEDDINGS_REGISTRY:
+            result.extend(embedding._list_supported_models())
+        return result
+
+    def __init__(
+        self,
+        model_name: str = "facebook/esm2_t12_35M_UR50D",
+        cache_dir: str | None = None,
+        threads: int | None = None,
+        providers: Sequence[OnnxProvider] | None = None,
+        cuda: bool | Device = Device.AUTO,
+        device_ids: list[int] | None = None,
+        lazy_load: bool = False,
+        **kwargs: Any,
+    ):
+        """
+        Initialize ProteinEmbedding.
+
+        Args:
+            model_name: Name of the model to use
+            cache_dir: Path to cache directory
+            threads: Number of threads for ONNX runtime
+            providers: ONNX execution providers
+            cuda: Whether to use CUDA
+            device_ids: List of device IDs for multi-GPU
+            lazy_load: Whether to load model lazily
+        """
+        super().__init__(model_name, cache_dir, threads, **kwargs)
+
+        for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
+            supported_models = EMBEDDING_MODEL_TYPE._list_supported_models()
+            if any(model_name.lower() == model.model.lower() for model in supported_models):
+                self.model = EMBEDDING_MODEL_TYPE(
+                    model_name=model_name,
+                    cache_dir=cache_dir,
+                    threads=threads,
+                    providers=providers,
+                    cuda=cuda,
+                    device_ids=device_ids,
+                    lazy_load=lazy_load,
+                    **kwargs,
+                )
+                return
+
+        raise ValueError(
+            f"Model {model_name} is not supported in ProteinEmbedding. "
+            "Please check the supported models using `ProteinEmbedding.list_supported_models()`"
+        )
+
+    def embed(
+        self,
+        sequences: str | Iterable[str],
+        batch_size: int = 32,
+        parallel: int | None = None,
+        **kwargs: Any,
+    ) -> Iterable[NumpyArray]:
+        """Embed protein sequences.
+
+        Args:
+            sequences: Single protein sequence or iterable of sequences (amino acid strings)
+            batch_size: Batch size for encoding
+            parallel: Number of parallel workers
+
+        Yields:
+            Embeddings as numpy arrays, one per sequence
+        """
+        yield from self.model.embed(sequences, batch_size, parallel, **kwargs)
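The `_post_process_onnx_output` method above is masked mean pooling over token embeddings followed by L2 normalization. A self-contained numpy sketch of the same computation, with toy shapes; the final normalization step is an assumption about what `fastembed.common.utils.normalize` does:

```python
import numpy as np

# Toy token embeddings: batch=2, seq_len=4, dim=3; the second sequence is padded.
embeddings = np.random.rand(2, 4, 3).astype(np.float32)
attention_mask = np.array([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=np.int64)

# Masked mean pooling, mirroring _post_process_onnx_output above.
mask = np.expand_dims(attention_mask, axis=-1)  # (2, 4, 1)
summed = (embeddings * mask).sum(axis=1)        # (2, 3): sum of non-padding tokens
counts = np.clip(mask.sum(axis=1), 1e-9, None)  # (2, 1): avoids division by zero
mean_pooled = summed / counts

# L2 normalization (assumed behavior of fastembed's normalize()).
norms = np.linalg.norm(mean_pooled, axis=1, keepdims=True)
unit = mean_pooled / np.clip(norms, 1e-12, None)

print(unit.shape)                    # (2, 3)
print(np.linalg.norm(unit, axis=1))  # ~[1. 1.]: each row has unit length
```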
fastembed/common/model_description.py
ADDED
@@ -0,0 +1,52 @@
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any
+
+
+@dataclass(frozen=True)
+class ModelSource:
+    hf: str | None = None
+    url: str | None = None
+    _deprecated_tar_struct: bool = False
+
+    @property
+    def deprecated_tar_struct(self) -> bool:
+        return self._deprecated_tar_struct
+
+    def __post_init__(self) -> None:
+        if self.hf is None and self.url is None:
+            raise ValueError(
+                f"At least one source should be set, current sources: hf={self.hf}, url={self.url}"
+            )
+
+
+@dataclass(frozen=True)
+class BaseModelDescription:
+    model: str
+    sources: ModelSource
+    model_file: str
+    description: str
+    license: str
+    size_in_GB: float
+    additional_files: list[str] = field(default_factory=list)
+
+
+@dataclass(frozen=True)
+class DenseModelDescription(BaseModelDescription):
+    dim: int | None = None
+    tasks: dict[str, Any] | None = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        assert self.dim is not None, "dim is required for dense model description"
+
+
+@dataclass(frozen=True)
+class SparseModelDescription(BaseModelDescription):
+    requires_idf: bool | None = None
+    vocab_size: int | None = None
+
+
+class PoolingType(str, Enum):
+    CLS = "CLS"
+    MEAN = "MEAN"
+    DISABLED = "DISABLED"
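These dataclasses are what the ESM-2 registry entry in `protein_embedding.py` instantiates. A sketch restating that entry against these definitions, using no fields or values beyond the source above:

```python
from fastembed.common.model_description import DenseModelDescription, ModelSource

# Restates the supported_protein_models entry from protein_embedding.py.
esm2 = DenseModelDescription(
    model="facebook/esm2_t12_35M_UR50D",
    dim=480,  # DenseModelDescription.__post_init__ asserts dim is not None
    description="Protein embeddings, ESM-2 35M parameters, 480 dimensions, 1024 max sequence length",
    license="mit",
    size_in_GB=0.13,
    sources=ModelSource(hf="nleroy917/esm2_t12_35M_UR50D-onnx"),  # hf or url must be set
    model_file="model.onnx",
    additional_files=["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"],
)
```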