fastembed-bio 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
@@ -0,0 +1,83 @@
+ from dataclasses import asdict
+ from typing import Iterable, Any, Type
+
+ from fastembed.common.model_description import DenseModelDescription, ModelSource
+ from fastembed.common.onnx_model import OnnxOutputContext
+ from fastembed.common.types import NumpyArray
+ from fastembed.late_interaction.late_interaction_embedding_base import (
+     LateInteractionTextEmbeddingBase,
+ )
+ from fastembed.text.onnx_embedding import OnnxTextEmbedding
+ from fastembed.text.onnx_text_model import TextEmbeddingWorker
+
+
+ supported_token_embeddings_models = [
+     DenseModelDescription(
+         model="jinaai/jina-embeddings-v2-small-en-tokens",
+         dim=512,
+         description="Text embeddings, Unimodal (text), English, 8192 input tokens truncation,"
+         " Prefixes for queries/documents: not necessary, 2023 year.",
+         license="apache-2.0",
+         size_in_GB=0.12,
+         sources=ModelSource(hf="xenova/jina-embeddings-v2-small-en"),
+         model_file="onnx/model.onnx",
+     ),
+ ]
+
+
+ class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase):
+     @classmethod
+     def _list_supported_models(cls) -> list[DenseModelDescription]:
+         """Lists the supported models.
+
+         Returns:
+             list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+         """
+         return supported_token_embeddings_models
+
+     @classmethod
+     def list_supported_models(cls) -> list[dict[str, Any]]:
+         """Lists the supported models.
+
+         Returns:
+             list[dict[str, Any]]: A list of dictionaries containing the model information.
+         """
+         return [asdict(model) for model in cls._list_supported_models()]
+
+     @classmethod
+     def _get_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
+         return TokensEmbeddingWorker
+
+     def _post_process_onnx_output(
+         self, output: OnnxOutputContext, **kwargs: Any
+     ) -> Iterable[NumpyArray]:
+         # Size: (batch_size, sequence_length, hidden_size)
+         embeddings = output.model_output
+         # Size: (batch_size, sequence_length)
+         assert output.attention_mask is not None
+         masks = output.attention_mask
+
+         # For each document we only select those embeddings that are not masked out
+         for i in range(embeddings.shape[0]):
+             yield embeddings[i, masks[i] == 1]
+
+     def embed(
+         self,
+         documents: str | Iterable[str],
+         batch_size: int = 256,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[NumpyArray]:
+         yield from super().embed(documents, batch_size=batch_size, parallel=parallel, **kwargs)
+
+
+ class TokensEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
+     def init_embedding(
+         self, model_name: str, cache_dir: str, **kwargs: Any
+     ) -> TokenEmbeddingsModel:
+         return TokenEmbeddingsModel(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=1,
+             **kwargs,
+         )
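
A minimal usage sketch for the token-level embedder introduced in the hunk above (fastembed/late_interaction/token_embeddings.py). It is not part of the package contents: it assumes the wheel is installed and that the inherited OnnxTextEmbedding constructor accepts the model name as its first argument, as the worker's keyword call suggests; the model name is the single entry from supported_token_embeddings_models.

from fastembed.late_interaction.token_embeddings import TokenEmbeddingsModel

# One multivector per document: a (num_unmasked_tokens, 512) array, because
# _post_process_onnx_output drops positions masked out by the attention mask.
model = TokenEmbeddingsModel(model_name="jinaai/jina-embeddings-v2-small-en-tokens")

documents = [
    "FastEmbed runs embedding models through ONNX Runtime.",
    "Token-level outputs are meant for late-interaction retrieval.",
]
for token_vectors in model.embed(documents, batch_size=32):
    print(token_vectors.shape)  # e.g. (n_tokens, 512)
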
@@ -0,0 +1,5 @@
+ from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding import (
+     LateInteractionMultimodalEmbedding,
+ )
+
+ __all__ = ["LateInteractionMultimodalEmbedding"]
@@ -0,0 +1,532 @@
+ import contextlib
+ from typing import Any, Iterable, Type, Optional, Sequence
+ import json
+
+ import numpy as np
+ from tokenizers import Encoding
+ from PIL import Image
+
+ from fastembed.common import ImageInput
+ from fastembed.common.model_description import DenseModelDescription, ModelSource
+ from fastembed.common.onnx_model import OnnxOutputContext
+ from fastembed.common.types import NumpyArray, OnnxProvider
+ from fastembed.common.utils import define_cache_dir, iter_batch
+ from fastembed.late_interaction_multimodal.late_interaction_multimodal_embedding_base import (
+     LateInteractionMultimodalEmbeddingBase,
+ )
+ from fastembed.late_interaction_multimodal.onnx_multimodal_model import (
+     OnnxMultimodalModel,
+     TextEmbeddingWorker,
+     ImageEmbeddingWorker,
+ )
+
+ supported_colmodernvbert_models: list[DenseModelDescription] = [
+     DenseModelDescription(
+         model="Qdrant/colmodernvbert",
+         dim=128,
+         description="The late-interaction version of ModernVBERT, CPU friendly, English, 2025.",
+         license="mit",
+         size_in_GB=1.0,
+         sources=ModelSource(hf="Qdrant/colmodernvbert"),
+         additional_files=["processor_config.json"],
+         model_file="model.onnx",
+     ),
+ ]
+
+
+ class ColModernVBERT(LateInteractionMultimodalEmbeddingBase, OnnxMultimodalModel[NumpyArray]):
+     """
+     The ModernVBERT/colmodernvbert model implementation. This model uses
+     bidirectional attention, which proves to work better for retrieval.
+
+     See: https://huggingface.co/ModernVBERT/colmodernvbert
+     """
+
+     VISUAL_PROMPT_PREFIX = (
+         "<|begin_of_text|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
+     )
+     QUERY_AUGMENTATION_TOKEN = "<end_of_utterance>"
+
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: Optional[str] = None,
+         threads: Optional[int] = None,
+         providers: Optional[Sequence[OnnxProvider]] = None,
+         cuda: bool = False,
+         device_ids: Optional[list[int]] = None,
+         lazy_load: bool = False,
+         device_id: Optional[int] = None,
+         specific_model_path: Optional[str] = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             model_name (str): The name of the model to use.
+             cache_dir (str, optional): The path to the cache directory.
+                 Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+                 Defaults to `fastembed_cache` in the system's temp directory.
+             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
+             providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                 Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+             cuda (bool, optional): Whether to use cuda for inference. Mutually exclusive with `providers`
+                 Defaults to False.
+             device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
+                 workers. Should be used with `cuda=True`, mutually exclusive with `providers`. Defaults to None.
+             lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                 Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
+             device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
+
+         Raises:
+             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+         """
+         super().__init__(model_name, cache_dir, threads, **kwargs)
+         self.providers = providers
+         self.lazy_load = lazy_load
+         self._extra_session_options = self._select_exposed_session_options(kwargs)
+
+         # List of device ids, that can be used for data parallel processing in workers
+         self.device_ids = device_ids
+         self.cuda = cuda
+
+         # This device_id will be used if we need to load model in current process
+         self.device_id: Optional[int] = None
+         if device_id is not None:
+             self.device_id = device_id
+         elif self.device_ids is not None:
+             self.device_id = self.device_ids[0]
+
+         self.model_description = self._get_model_description(model_name)
+         self.cache_dir = str(define_cache_dir(cache_dir))
+
+         self._specific_model_path = specific_model_path
+         self._model_dir = self.download_model(
+             self.model_description,
+             self.cache_dir,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+         )
+         self.mask_token_id = None
+         self.pad_token_id = None
+         self.image_seq_len: Optional[int] = None
+         self.max_image_size: Optional[int] = None
+         self.image_size: Optional[int] = None
+
+         if not self.lazy_load:
+             self.load_onnx_model()
+
+     @classmethod
+     def _list_supported_models(cls) -> list[DenseModelDescription]:
+         """Lists the supported models.
+
+         Returns:
+             list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+         """
+         return supported_colmodernvbert_models
+
+     def load_onnx_model(self) -> None:
+         self._load_onnx_model(
+             model_dir=self._model_dir,
+             model_file=self.model_description.model_file,
+             threads=self.threads,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_id=self.device_id,
+             extra_session_options=self._extra_session_options,
+         )
+
+         # Load image processing configuration
+         processor_config_path = self._model_dir / "processor_config.json"
+         with open(processor_config_path) as f:
+             processor_config = json.load(f)
+         self.image_seq_len = processor_config.get("image_seq_len", 64)
+
+         preprocessor_config_path = self._model_dir / "preprocessor_config.json"
+         with open(preprocessor_config_path) as f:
+             preprocessor_config = json.load(f)
+         self.max_image_size = preprocessor_config.get("max_image_size", {}).get(
+             "longest_edge", 512
+         )
+
+         # Load model configuration
+         config_path = self._model_dir / "config.json"
+         with open(config_path) as f:
+             model_config = json.load(f)
+         vision_config = model_config.get("vision_config", {})
+         self.image_size = vision_config.get("image_size", 512)
+
+     def _preprocess_onnx_text_input(
+         self, onnx_input: dict[str, NumpyArray], **kwargs: Any
+     ) -> dict[str, NumpyArray]:
+         """
+         Pre-process the ONNX text input by adding an empty image placeholder.
+
+         Args:
+             onnx_input (dict[str, NumpyArray]): Tokenized text input for the ONNX model.
+
+         Returns:
+             dict[str, NumpyArray]: The same input with an all-zeros `pixel_values` placeholder added.
+         """
+         batch_size, seq_length = onnx_input["input_ids"].shape
+         empty_image_placeholder: NumpyArray = np.zeros(
+             (batch_size, seq_length, 3, self.image_size, self.image_size),
+             dtype=np.float32,  # type: ignore[type-var,arg-type,assignment]
+         )
+         onnx_input["pixel_values"] = empty_image_placeholder
+         return onnx_input
+
+     def _post_process_onnx_text_output(
+         self,
+         output: OnnxOutputContext,
+     ) -> Iterable[NumpyArray]:
+         """
+         Post-process the ONNX model output to convert it into a usable format.
+
+         Args:
+             output (OnnxOutputContext): The raw output from the ONNX model.
+
+         Returns:
+             Iterable[NumpyArray]: Post-processed output as NumPy arrays.
+         """
+         return output.model_output
+
+     def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
+         # Add query augmentation tokens (matching process_queries logic from colpali-engine)
+         augmented_queries = [doc + self.QUERY_AUGMENTATION_TOKEN * 10 for doc in documents]
+         encoded = self.tokenizer.encode_batch(augmented_queries)  # type: ignore[union-attr]
+         return encoded
+
+     def token_count(
+         self,
+         texts: str | Iterable[str],
+         batch_size: int = 1024,
+         include_extension: bool = False,
+         **kwargs: Any,
+     ) -> int:
+         if not hasattr(self, "model") or self.model is None:
+             self.load_onnx_model()  # loads the tokenizer as well
+         token_num = 0
+         texts = [texts] if isinstance(texts, str) else texts
+         assert self.tokenizer is not None
+         tokenize_func = self.tokenize if include_extension else self.tokenizer.encode_batch
+         for batch in iter_batch(texts, batch_size):
+             token_num += sum([sum(encoding.attention_mask) for encoding in tokenize_func(batch)])
+         return token_num
+
+     def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
+         with contextlib.ExitStack() as stack:
+             image_files = [
+                 stack.enter_context(Image.open(image))
+                 if not isinstance(image, Image.Image)
+                 else image
+                 for image in images
+             ]
+             assert self.processor is not None, "Processor is not initialized"
+             processed = self.processor(image_files)
+             encoded, attention_mask, metadata = self._process_nested_patches(processed)  # type: ignore[arg-type]
+
+             onnx_input = {"pixel_values": encoded, "attention_mask": attention_mask}
+             onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs)
+             model_output = self.model.run(None, onnx_input)  # type: ignore[union-attr]
+
+             return OnnxOutputContext(
+                 model_output=model_output[0],
+                 attention_mask=attention_mask,  # type: ignore[arg-type]
+                 metadata=metadata,
+             )
+
+     @staticmethod
+     def _process_nested_patches(
+         processed: list[list[NumpyArray]],
+     ) -> tuple[NumpyArray, NumpyArray, dict[str, Any]]:
+         """
+         Process nested image patches (from ImageSplitter).
+
+         Args:
+             processed: List of patch lists, one per image [[img1_patches], [img2_patches], ...]
+
+         Returns:
+             tuple: (encoded array, attention_mask, metadata)
+                 - encoded: (batch_size, max_patches, C, H, W)
+                 - attention_mask: (batch_size, max_patches) with 1 for real patches, 0 for padding
+                 - metadata: Dict with 'patch_counts' key
+         """
+         patch_counts = [len(patches) for patches in processed]
+         max_patches = max(patch_counts)
+
+         # Get dimensions from first patch
+         channels, height, width = processed[0][0].shape
+         batch_size = len(processed)
+
+         # Create padded array
+         encoded = np.zeros(
+             (batch_size, max_patches, channels, height, width), dtype=processed[0][0].dtype
+         )
+
+         # Create attention mask (1 for real patches, 0 for padding)
+         attention_mask = np.zeros((batch_size, max_patches), dtype=np.int64)
+
+         # Fill in patches and attention mask
+         for i, patches in enumerate(processed):
+             for j, patch in enumerate(patches):
+                 encoded[i, j] = patch
+                 attention_mask[i, j] = 1
+
+         metadata = {"patch_counts": patch_counts}
+         return encoded, attention_mask, metadata  # type: ignore[return-value]
+
+     def _preprocess_onnx_image_input(
+         self, onnx_input: dict[str, np.ndarray], **kwargs: Any
+     ) -> dict[str, NumpyArray]:
+         """
+         Add text input placeholders for image data, following Idefics3 processing logic.
+
+         Constructs input_ids dynamically based on the actual number of image patches,
+         using the same token expansion logic as Idefics3Processor.
+
+         Args:
+             onnx_input: Dict with 'pixel_values' (batch, num_patches, C, H, W)
+                 and 'attention_mask' (batch, num_patches) indicating real patches
+             **kwargs: Additional arguments
+
+         Returns:
+             Updated onnx_input with 'input_ids' and updated 'attention_mask' for token sequence
+         """
+         # The attention_mask in onnx_input has a shape of (batch_size, num_patches),
+         # and should be used to create an attention mask matching the input_ids shape.
+         patch_attention_mask = onnx_input["attention_mask"]
+         pixel_values = onnx_input["pixel_values"]
+
+         batch_size = pixel_values.shape[0]
+         batch_input_ids = []
+
+         # Build input_ids for each image based on its actual patch count
+         for i in range(batch_size):
+             # Count real patches (non-padded) from attention mask
+             patch_count = int(np.sum(patch_attention_mask[i]))
+
+             # Compute rows/cols from patch count
+             rows, cols = self._compute_rows_cols_from_patches(patch_count)
+
+             # Build input_ids for this image
+             input_ids = self._build_input_ids_for_image(rows, cols)
+             batch_input_ids.append(input_ids)
+
+         # Pad sequences to max length in batch
+         max_len = max(len(ids) for ids in batch_input_ids)
+
+         # Get padding config from tokenizer
+         padding_direction = self.tokenizer.padding["direction"]  # type: ignore[index,union-attr]
+         pad_token_id = self.tokenizer.padding["pad_id"]  # type: ignore[index,union-attr]
+
+         # Initialize with pad token
+         padded_input_ids = np.full((batch_size, max_len), pad_token_id, dtype=np.int64)
+         attention_mask = np.zeros((batch_size, max_len), dtype=np.int64)
+
+         for i, input_ids in enumerate(batch_input_ids):
+             seq_len = len(input_ids)
+             if padding_direction == "left":
+                 # Left padding: place tokens at the END of the array
+                 start_idx = max_len - seq_len
+                 padded_input_ids[i, start_idx:] = input_ids
+                 attention_mask[i, start_idx:] = 1
+             else:
+                 # Right padding: place tokens at the START of the array
+                 padded_input_ids[i, :seq_len] = input_ids
+                 attention_mask[i, :seq_len] = 1
+
+         onnx_input["input_ids"] = padded_input_ids
+         # Update attention_mask with token-level data
+         onnx_input["attention_mask"] = attention_mask
+         return onnx_input
+
+     @staticmethod
+     def _compute_rows_cols_from_patches(patch_count: int) -> tuple[int, int]:
+         if patch_count <= 1:
+             return 0, 0
+
+         # Subtract 1 for the global image
+         grid_patches = patch_count - 1
+
+         # Find rows and cols (assume square or near-square grid)
+         rows = int(grid_patches**0.5)
+         cols = grid_patches // rows
+
+         # Verify the calculation
+         if rows * cols + 1 != patch_count:
+             # Handle non-square grids
+             for r in range(1, grid_patches + 1):
+                 if grid_patches % r == 0:
+                     c = grid_patches // r
+                     if r * c + 1 == patch_count:
+                         return r, c
+             # Fallback: treat as unsplit
+             return 0, 0
+
+         return rows, cols
+
+     def _create_single_image_prompt_string(self) -> str:
+         return (
+             "<fake_token_around_image>"
+             + "<global-img>"
+             + "<image>" * self.image_seq_len  # type: ignore[operator]
+             + "<fake_token_around_image>"
+         )
+
+     def _create_split_image_prompt_string(self, rows: int, cols: int) -> str:
+         text_split_images = ""
+
+         # Add tokens for each patch in the grid
+         for n_h in range(rows):
+             for n_w in range(cols):
+                 text_split_images += (
+                     "<fake_token_around_image>"
+                     + f"<row_{n_h + 1}_col_{n_w + 1}>"
+                     + "<image>" * self.image_seq_len  # type: ignore[operator]
+                 )
+             text_split_images += "\n"
+
+         # Add global image at the end
+         text_split_images += (
+             "\n<fake_token_around_image>"
+             + "<global-img>"
+             + "<image>" * self.image_seq_len  # type: ignore[operator]
+             + "<fake_token_around_image>"
+         )
+
+         return text_split_images
+
+     def _build_input_ids_for_image(self, rows: int, cols: int) -> np.ndarray:
+         # Create the appropriate image prompt string
+         if rows == 0 and cols == 0:
+             image_prompt_tokens = self._create_single_image_prompt_string()
+         else:
+             image_prompt_tokens = self._create_split_image_prompt_string(rows, cols)
+
+         # Replace <image> in visual prompt with expanded tokens
+         # The visual prompt is: "<|begin_of_text|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
+         expanded_prompt = self.VISUAL_PROMPT_PREFIX.replace("<image>", image_prompt_tokens)
+
+         # Tokenize the complete prompt
+         encoded = self.tokenizer.encode(expanded_prompt)  # type: ignore[union-attr]
+
+         # Convert to numpy array
+         return np.array(encoded.ids, dtype=np.int64)
+
+     def _post_process_onnx_image_output(
+         self,
+         output: OnnxOutputContext,
+     ) -> Iterable[NumpyArray]:
+         """
+         Post-process the ONNX model output to convert it into a usable format.
+
+         Args:
+             output (OnnxOutputContext): The raw output from the ONNX model.
+
+         Returns:
+             Iterable[NumpyArray]: Post-processed output as NumPy arrays.
+         """
+         assert self.model_description.dim is not None, "Model dim is not defined"
+         return output.model_output.reshape(
+             output.model_output.shape[0], -1, self.model_description.dim
+         )
+
+     def embed_text(
+         self,
+         documents: str | Iterable[str],
+         batch_size: int = 256,
+         parallel: Optional[int] = None,
+         **kwargs: Any,
+     ) -> Iterable[NumpyArray]:
+         """
+         Encode a list of documents into a list of embeddings.
+
+         Args:
+             documents: Iterator of documents or single document to embed
+             batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+             parallel:
+                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                 If 0, use all available cores.
+                 If None, don't use data-parallel processing, use default onnxruntime threading instead.
+
+         Returns:
+             List of embeddings, one per document
+         """
+         yield from self._embed_documents(
+             model_name=self.model_name,
+             cache_dir=str(self.cache_dir),
+             documents=documents,
+             batch_size=batch_size,
+             parallel=parallel,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_ids=self.device_ids,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+             extra_session_options=self._extra_session_options,
+             **kwargs,
+         )
+
+     def embed_image(
+         self,
+         images: ImageInput | Iterable[ImageInput],
+         batch_size: int = 16,
+         parallel: Optional[int] = None,
+         **kwargs: Any,
+     ) -> Iterable[NumpyArray]:
+         """
+         Encode a list of images into a list of embeddings.
+
+         Args:
+             images: Iterator of image paths or single image path to embed
+             batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+             parallel:
+                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                 If 0, use all available cores.
+                 If None, don't use data-parallel processing, use default onnxruntime threading instead.
+
+         Returns:
+             List of embeddings, one per image
+         """
+         yield from self._embed_images(
+             model_name=self.model_name,
+             cache_dir=str(self.cache_dir),
+             images=images,
+             batch_size=batch_size,
+             parallel=parallel,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_ids=self.device_ids,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+             extra_session_options=self._extra_session_options,
+             **kwargs,
+         )
+
+     @classmethod
+     def _get_text_worker_class(cls) -> Type[TextEmbeddingWorker[NumpyArray]]:
+         return ColModernVBERTTextEmbeddingWorker
+
+     @classmethod
+     def _get_image_worker_class(cls) -> Type[ImageEmbeddingWorker[NumpyArray]]:
+         return ColModernVBERTImageEmbeddingWorker
+
+
+ class ColModernVBERTTextEmbeddingWorker(TextEmbeddingWorker[NumpyArray]):
+     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> ColModernVBERT:
+         return ColModernVBERT(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=1,
+             **kwargs,
+         )
+
+
+ class ColModernVBERTImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]):
+     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> ColModernVBERT:
+         return ColModernVBERT(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=1,
+             **kwargs,
+         )
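
Similarly, a minimal usage sketch for ColModernVBERT as defined in the hunk above (fastembed/late_interaction_multimodal/colmodernvbert.py), based only on the embed_text/embed_image signatures and the Qdrant/colmodernvbert model description; the positional model-name argument and the local file page_scan.png are illustrative assumptions, not package documentation.

from fastembed.late_interaction_multimodal.colmodernvbert import ColModernVBERT

# Late-interaction (ColBERT-style) retrieval: queries and page images are both
# embedded as multivectors with 128 values per position (dim=128 above).
model = ColModernVBERT(model_name="Qdrant/colmodernvbert")

query_vectors = list(model.embed_text(["total amount on the 2024 invoice"]))
image_vectors = list(model.embed_image(["page_scan.png"], batch_size=4))  # hypothetical local scan

print(query_vectors[0].shape)  # (n_query_tokens, 128)
print(image_vectors[0].shape)  # (n_image_patches, 128)
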