fastembed_bio-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/late_interaction/token_embeddings.py
@@ -0,0 +1,83 @@
from dataclasses import asdict
from typing import Iterable, Any, Type

from fastembed.common.model_description import DenseModelDescription, ModelSource
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import NumpyArray
from fastembed.late_interaction.late_interaction_embedding_base import (
    LateInteractionTextEmbeddingBase,
)
from fastembed.text.onnx_embedding import OnnxTextEmbedding
from fastembed.text.onnx_text_model import TextEmbeddingWorker


supported_token_embeddings_models = [
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-small-en-tokens",
        dim=512,
        description="Text embeddings, Unimodal (text), English, 8192 input tokens truncation,"
        " Prefixes for queries/documents: not necessary, 2023 year.",
        license="apache-2.0",
        size_in_GB=0.12,
        sources=ModelSource(hf="xenova/jina-embeddings-v2-small-en"),
        model_file="onnx/model.onnx",
    ),
]


class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase):
    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Lists the supported models.

        Returns:
            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
        """
        return supported_token_embeddings_models

    @classmethod
    def list_supported_models(cls) -> list[dict[str, Any]]:
        """Lists the supported models.

        Returns:
            list[dict[str, Any]]: A list of dictionaries containing the model information.
        """
        return [asdict(model) for model in cls._list_supported_models()]

    @classmethod
    def _get_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
        return TokensEmbeddingWorker

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, **kwargs: Any
    ) -> Iterable[NumpyArray]:
        # Size: (batch_size, sequence_length, hidden_size)
        embeddings = output.model_output
        # Size: (batch_size, sequence_length)
        assert output.attention_mask is not None
        masks = output.attention_mask

        # For each document we only select those embeddings that are not masked out
        for i in range(embeddings.shape[0]):
            yield embeddings[i, masks[i] == 1]

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        yield from super().embed(documents, batch_size=batch_size, parallel=parallel, **kwargs)


class TokensEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
    def init_embedding(
        self, model_name: str, cache_dir: str, **kwargs: Any
    ) -> TokenEmbeddingsModel:
        return TokenEmbeddingsModel(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
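A minimal usage sketch for the token-level embedder above (not part of the wheel): it assumes the constructor mirrors OnnxTextEmbedding and that the package exposes the module path shown in this diff. Per `_post_process_onnx_output`, each document yields one array of shape (num_tokens, 512), where num_tokens counts only the non-masked positions.

# Hypothetical usage sketch; the model name comes from supported_token_embeddings_models above.
from fastembed.late_interaction.token_embeddings import TokenEmbeddingsModel

model = TokenEmbeddingsModel(model_name="jinaai/jina-embeddings-v2-small-en-tokens")
docs = ["FastEmbed yields one vector per non-masked token."]

for token_matrix in model.embed(docs, batch_size=8):
    # One (num_tokens, 512) matrix per document: masked-out positions are dropped.
    print(token_matrix.shape)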
fastembed/late_interaction_multimodal/colmodernvbert.py
@@ -0,0 +1,532 @@
import contextlib
from typing import Any, Iterable, Type, Optional, Sequence
import json

import numpy as np
from tokenizers import Encoding
from PIL import Image

from fastembed.common import ImageInput
from fastembed.common.model_description import DenseModelDescription, ModelSource
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import NumpyArray, OnnxProvider
from fastembed.common.utils import define_cache_dir, iter_batch
from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
    LateInteractionMultimodalEmbeddingBase,
)
from fastembed.late_interaction_multimodal.onnx_multimodal_model import (
    OnnxMultimodalModel,
    TextEmbeddingWorker,
    ImageEmbeddingWorker,
)

supported_colmodernvbert_models: list[DenseModelDescription] = [
    DenseModelDescription(
        model="Qdrant/colmodernvbert",
        dim=128,
        description="The late-interaction version of ModernVBERT, CPU friendly, English, 2025.",
        license="mit",
        size_in_GB=1.0,
        sources=ModelSource(hf="Qdrant/colmodernvbert"),
        additional_files=["processor_config.json"],
        model_file="model.onnx",
    ),
]


class ColModernVBERT(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[NumpyArray]):
    """
    The ModernVBERT/colmodernvbert model implementation. This model uses
    bidirectional attention, which proves to work better for retrieval.

    See: https://huggingface.co/ModernVBERT/colmodernvbert
    """

    VISUAL_PROMPT_PREFIX = (
        "<|begin_of_text|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
    )
    QUERY_AUGMENTATION_TOKEN = "<end_of_utterance>"

    def __init__(
        self,
        model_name: str,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        providers: Optional[Sequence[OnnxProvider]] = None,
        cuda: bool = False,
        device_ids: Optional[list[int]] = None,
        lazy_load: bool = False,
        device_id: Optional[int] = None,
        specific_model_path: Optional[str] = None,
        **kwargs: Any,
    ):
        """
        Args:
            model_name (str): The name of the model to use.
            cache_dir (str, optional): The path to the cache directory.
                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                Defaults to `fastembed_cache` in the system's temp directory.
            threads (int, optional): The number of threads a single onnxruntime session can use. Defaults to None.
            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
            cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`.
                Defaults to False.
            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
                workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
                Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.

        Raises:
            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
        """
        super().__init__(model_name, cache_dir, threads, **kwargs)
        self.providers = providers
        self.lazy_load = lazy_load
        self._extra_session_options = self._select_exposed_session_options(kwargs)

        # List of device ids, that can be used for data parallel processing in workers
        self.device_ids = device_ids
        self.cuda = cuda

        # This device_id will be used if we need to load model in current process
        self.device_id: Optional[int] = None
        if device_id is not None:
            self.device_id = device_id
        elif self.device_ids is not None:
            self.device_id = self.device_ids[0]

        self.model_description = self._get_model_description(model_name)
        self.cache_dir = str(define_cache_dir(cache_dir))

        self._specific_model_path = specific_model_path
        self._model_dir = self.download_model(
            self.model_description,
            self.cache_dir,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
        )
        self.mask_token_id = None
        self.pad_token_id = None
        self.image_seq_len: Optional[int] = None
        self.max_image_size: Optional[int] = None
        self.image_size: Optional[int] = None

        if not self.lazy_load:
            self.load_onnx_model()

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Lists the supported models.

        Returns:
            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
        """
        return supported_colmodernvbert_models

    def load_onnx_model(self) -> None:
        self._load_onnx_model(
            model_dir=self._model_dir,
            model_file=self.model_description.model_file,
            threads=self.threads,
            providers=self.providers,
            cuda=self.cuda,
            device_id=self.device_id,
            extra_session_options=self._extra_session_options,
        )

        # Load image processing configuration
        processor_config_path = self._model_dir / "processor_config.json"
        with open(processor_config_path) as f:
            processor_config = json.load(f)
        self.image_seq_len = processor_config.get("image_seq_len", 64)

        preprocessor_config_path = self._model_dir / "preprocessor_config.json"
        with open(preprocessor_config_path) as f:
            preprocessor_config = json.load(f)
        self.max_image_size = preprocessor_config.get("max_image_size", {}).get(
            "longest_edge", 512
        )

        # Load model configuration
        config_path = self._model_dir / "config.json"
        with open(config_path) as f:
            model_config = json.load(f)
        vision_config = model_config.get("vision_config", {})
        self.image_size = vision_config.get("image_size", 512)

    def _preprocess_onnx_text_input(
        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
    ) -> dict[str, NumpyArray]:
        """
        Add an empty `pixel_values` placeholder to a text-only input so that it matches
        the multimodal ONNX graph signature.

        Args:
            onnx_input: Tokenized ONNX input containing `input_ids`.

        Returns:
            dict[str, NumpyArray]: The same input with a zero-filled `pixel_values` tensor added.
        """
        batch_size, seq_length = onnx_input["input_ids"].shape
        empty_image_placeholder: NumpyArray = np.zeros(
            (batch_size, seq_length, 3, self.image_size, self.image_size),
            dtype=np.float32,  # type: ignore[type-var,arg-type,assignment]
        )
        onnx_input["pixel_values"] = empty_image_placeholder
        return onnx_input

    def _post_process_onnx_text_output(
        self,
        output: OnnxOutputContext,
    ) -> Iterable[NumpyArray]:
        """
        Post-process the ONNX model output to convert it into a usable format.

        Args:
            output (OnnxOutputContext): The raw output from the ONNX model.

        Returns:
            Iterable[NumpyArray]: Post-processed output as NumPy arrays.
        """
        return output.model_output

    def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
        # Add query augmentation tokens (matching process_queries logic from colpali-engine)
        augmented_queries = [doc + self.QUERY_AUGMENTATION_TOKEN * 10 for doc in documents]
        encoded = self.tokenizer.encode_batch(augmented_queries)  # type: ignore[union-attr]
        return encoded

    def token_count(
        self,
        texts: str | Iterable[str],
        batch_size: int = 1024,
        include_extension: bool = False,
        **kwargs: Any,
    ) -> int:
        if not hasattr(self, "model") or self.model is None:
            self.load_onnx_model()  # loads the tokenizer as well
        token_num = 0
        texts = [texts] if isinstance(texts, str) else texts
        assert self.tokenizer is not None
        tokenize_func = self.tokenize if include_extension else self.tokenizer.encode_batch
        for batch in iter_batch(texts, batch_size):
            token_num += sum([sum(encoding.attention_mask) for encoding in tokenize_func(batch)])
        return token_num

    def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
        with contextlib.ExitStack() as stack:
            image_files = [
                stack.enter_context(Image.open(image))
                if not isinstance(image, Image.Image)
                else image
                for image in images
            ]
            assert self.processor is not None, "Processor is not initialized"
            processed = self.processor(image_files)
            encoded, attention_mask, metadata = self._process_nested_patches(processed)  # type: ignore[arg-type]

            onnx_input = {"pixel_values": encoded, "attention_mask": attention_mask}
            onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs)
            model_output = self.model.run(None, onnx_input)  # type: ignore[union-attr]

            return OnnxOutputContext(
                model_output=model_output[0],
                attention_mask=attention_mask,  # type: ignore[arg-type]
                metadata=metadata,
            )

    @staticmethod
    def _process_nested_patches(
        processed: list[list[NumpyArray]],
    ) -> tuple[NumpyArray, NumpyArray, dict[str, Any]]:
        """
        Process nested image patches (from ImageSplitter).

        Args:
            processed: List of patch lists, one per image [[img1_patches], [img2_patches], ...]

        Returns:
            tuple: (encoded array, attention_mask, metadata)
                - encoded: (batch_size, max_patches, C, H, W)
                - attention_mask: (batch_size, max_patches) with 1 for real patches, 0 for padding
                - metadata: Dict with 'patch_counts' key
        """
        patch_counts = [len(patches) for patches in processed]
        max_patches = max(patch_counts)

        # Get dimensions from first patch
        channels, height, width = processed[0][0].shape
        batch_size = len(processed)

        # Create padded array
        encoded = np.zeros(
            (batch_size, max_patches, channels, height, width), dtype=processed[0][0].dtype
        )

        # Create attention mask (1 for real patches, 0 for padding)
        attention_mask = np.zeros((batch_size, max_patches), dtype=np.int64)

        # Fill in patches and attention mask
        for i, patches in enumerate(processed):
            for j, patch in enumerate(patches):
                encoded[i, j] = patch
                attention_mask[i, j] = 1

        metadata = {"patch_counts": patch_counts}
        return encoded, attention_mask, metadata  # type: ignore[return-value]

    def _preprocess_onnx_image_input(
        self, onnx_input: dict[str, np.ndarray], **kwargs: Any
    ) -> dict[str, NumpyArray]:
        """
        Add text input placeholders for image data, following Idefics3 processing logic.

        Constructs input_ids dynamically based on the actual number of image patches,
        using the same token expansion logic as Idefics3Processor.

        Args:
            onnx_input: Dict with 'pixel_values' (batch, num_patches, C, H, W)
                and 'attention_mask' (batch, num_patches) indicating real patches
            **kwargs: Additional arguments

        Returns:
            Updated onnx_input with 'input_ids' and updated 'attention_mask' for token sequence
        """
        # The attention_mask in onnx_input has a shape of (batch_size, num_patches),
        # and should be used to create an attention mask matching the input_ids shape.
        patch_attention_mask = onnx_input["attention_mask"]
        pixel_values = onnx_input["pixel_values"]

        batch_size = pixel_values.shape[0]
        batch_input_ids = []

        # Build input_ids for each image based on its actual patch count
        for i in range(batch_size):
            # Count real patches (non-padded) from attention mask
            patch_count = int(np.sum(patch_attention_mask[i]))

            # Compute rows/cols from patch count
            rows, cols = self._compute_rows_cols_from_patches(patch_count)

            # Build input_ids for this image
            input_ids = self._build_input_ids_for_image(rows, cols)
            batch_input_ids.append(input_ids)

        # Pad sequences to max length in batch
        max_len = max(len(ids) for ids in batch_input_ids)

        # Get padding config from tokenizer
        padding_direction = self.tokenizer.padding["direction"]  # type: ignore[index,union-attr]
        pad_token_id = self.tokenizer.padding["pad_id"]  # type: ignore[index,union-attr]

        # Initialize with pad token
        padded_input_ids = np.full((batch_size, max_len), pad_token_id, dtype=np.int64)
        attention_mask = np.zeros((batch_size, max_len), dtype=np.int64)

        for i, input_ids in enumerate(batch_input_ids):
            seq_len = len(input_ids)
            if padding_direction == "left":
                # Left padding: place tokens at the END of the array
                start_idx = max_len - seq_len
                padded_input_ids[i, start_idx:] = input_ids
                attention_mask[i, start_idx:] = 1
            else:
                # Right padding: place tokens at the START of the array
                padded_input_ids[i, :seq_len] = input_ids
                attention_mask[i, :seq_len] = 1

        onnx_input["input_ids"] = padded_input_ids
        # Update attention_mask with token-level data
        onnx_input["attention_mask"] = attention_mask
        return onnx_input

    @staticmethod
    def _compute_rows_cols_from_patches(patch_count: int) -> tuple[int, int]:
        if patch_count <= 1:
            return 0, 0

        # Subtract 1 for the global image
        grid_patches = patch_count - 1

        # Find rows and cols (assume square or near-square grid)
        rows = int(grid_patches**0.5)
        cols = grid_patches // rows

        # Verify the calculation
        if rows * cols + 1 != patch_count:
            # Handle non-square grids
            for r in range(1, grid_patches + 1):
                if grid_patches % r == 0:
                    c = grid_patches // r
                    if r * c + 1 == patch_count:
                        return r, c
            # Fallback: treat as unsplit
            return 0, 0

        return rows, cols

    def _create_single_image_prompt_string(self) -> str:
        return (
            "<fake_token_around_image>"
            + "<global-img>"
            + "<image>" * self.image_seq_len  # type: ignore[operator]
            + "<fake_token_around_image>"
        )

    def _create_split_image_prompt_string(self, rows: int, cols: int) -> str:
        text_split_images = ""

        # Add tokens for each patch in the grid
        for n_h in range(rows):
            for n_w in range(cols):
                text_split_images += (
                    "<fake_token_around_image>"
                    + f"<row_{n_h + 1}_col_{n_w + 1}>"
                    + "<image>" * self.image_seq_len  # type: ignore[operator]
                )
            text_split_images += "\n"

        # Add global image at the end
        text_split_images += (
            "\n<fake_token_around_image>"
            + "<global-img>"
            + "<image>" * self.image_seq_len  # type: ignore[operator]
            + "<fake_token_around_image>"
        )

        return text_split_images

    def _build_input_ids_for_image(self, rows: int, cols: int) -> np.ndarray:
        # Create the appropriate image prompt string
        if rows == 0 and cols == 0:
            image_prompt_tokens = self._create_single_image_prompt_string()
        else:
            image_prompt_tokens = self._create_split_image_prompt_string(rows, cols)

        # Replace <image> in visual prompt with expanded tokens
        # The visual prompt is: "<|begin_of_text|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
        expanded_prompt = self.VISUAL_PROMPT_PREFIX.replace("<image>", image_prompt_tokens)

        # Tokenize the complete prompt
        encoded = self.tokenizer.encode(expanded_prompt)  # type: ignore[union-attr]

        # Convert to numpy array
        return np.array(encoded.ids, dtype=np.int64)

    def _post_process_onnx_image_output(
        self,
        output: OnnxOutputContext,
    ) -> Iterable[NumpyArray]:
        """
        Post-process the ONNX model output to convert it into a usable format.

        Args:
            output (OnnxOutputContext): The raw output from the ONNX model.

        Returns:
            Iterable[NumpyArray]: Post-processed output as NumPy arrays.
        """
        assert self.model_description.dim is not None, "Model dim is not defined"
        return output.model_output.reshape(
            output.model_output.shape[0], -1, self.model_description.dim
        )

    def embed_text(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: Optional[int] = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """
        Encode a list of documents into list of embeddings.

        Args:
            documents: Iterator of documents or single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per document
        """
        yield from self._embed_documents(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            documents=documents,
            batch_size=batch_size,
            parallel=parallel,
            providers=self.providers,
            cuda=self.cuda,
            device_ids=self.device_ids,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
            extra_session_options=self._extra_session_options,
            **kwargs,
        )

    def embed_image(
        self,
        images: ImageInput | Iterable[ImageInput],
        batch_size: int = 16,
        parallel: Optional[int] = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """
        Encode a list of images into list of embeddings.

        Args:
            images: Iterator of image paths or single image path to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per image
        """
        yield from self._embed_images(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            images=images,
            batch_size=batch_size,
            parallel=parallel,
            providers=self.providers,
            cuda=self.cuda,
            device_ids=self.device_ids,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
            extra_session_options=self._extra_session_options,
            **kwargs,
        )

    @classmethod
    def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
        return ColModernVBERTTextEmbeddingWorker

    @classmethod
    def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker[NumpyArray]]:
        return ColModernVBERTImageEmbeddingWorker


class ColModernVBERTTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
    def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> ColModernVBERT:
        return ColModernVBERT(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )


class ColModernVBERTImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]):
    def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> ColModernVBERT:
        return ColModernVBERT(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
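A minimal usage sketch for the late-interaction multimodal class above (not part of the wheel): it assumes `embed_text` and `embed_image` behave as defined in this diff, and the PNG path is a placeholder. Queries receive 10 `<end_of_utterance>` augmentation tokens via `tokenize`; images are split into a patch grid plus a global image before encoding.

# Hypothetical usage sketch; the image path below is a placeholder.
from fastembed.late_interaction_multimodal.colmodernvbert import ColModernVBERT

model = ColModernVBERT(model_name="Qdrant/colmodernvbert")

# Text side: one multi-vector matrix per query, 128 dimensions per token.
query_matrices = list(model.embed_text(["what does the chart on page 3 show?"]))

# Image side: one multi-vector matrix per image, 128 dimensions per visual token.
image_matrices = list(model.embed_image(["page_3.png"], batch_size=4))

# Late-interaction (MaxSim-style) scoring compares query token vectors against image token vectors.
print(query_matrices[0].shape, image_matrices[0].shape)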